
前面介绍的都是无状态的单次请求,如果希望连续聊天,并且AI能根据历史的聊天记录给出相关联的回答,怎么做呢?看下面的例子:
// Sliding-window memory: keeps only the most recent 10 messages.
memoryBuffer := memory.NewConversationWindowBuffer(10)
// Conversation chain that reads/writes history through that memory.
chatChain := chains.NewConversation(llm, memoryBuffer)
messages := []string{
"你好,我叫PBR",
"你知道我叫什么吗?",
"你可以解决什么问题?",
}
for _, message := range messages {
completion, err := chains.Run(ctx, chatChain, message)
// Retry loop: on any error, wait 30s and try again until success.
// NOTE(review): this retries forever — acceptable for a demo
// (e.g. riding out rate limits), but production code should bound
// the number of retries and respect ctx cancellation.
for {
if err == nil {
break
}
time.Sleep(30 * time.Second)
completion, err = chains.Run(ctx, chatChain, message)
}
// Dump the accumulated history so the windowing behavior is visible.
// Error deliberately ignored here — demo code only.
chatMessages, _ := memoryBuffer.ChatHistory.Messages(ctx)
fmt.Printf("上下文对话历史:%v\n", chatMessages)
fmt.Printf("输入:%v\n输出:%v\n", message, completion)
}

可以看到,我们定义了NewConversationWindowBuffer:带聊天记录条数窗口的buffer。当然也可以根据实际情况定义更多类型的buffer:简单buffer,或者指定token长度的buffer:
// Unbounded buffer: keeps the entire conversation history.
memoryBuffer := memory.NewConversationBuffer()
// Window buffer: keeps only the most recent 10 entries.
memoryBuffer := memory.NewConversationWindowBuffer(10)
memoryBuffer := memory.NewConversationTokenBuffer(llm, 1024)

以下摘自 github.com/tmc/langchaingo@v0.1.13/memory/window_buffer.go:
// NewConversationWindowBuffer returns a ConversationWindowBuffer that
// keeps a bounded window of conversation history. A non-positive size
// falls back to defaultConversationWindowSize.
func NewConversationWindowBuffer(
conversationWindowSize int,
options ...ConversationBufferOption,
) *ConversationWindowBuffer {
if conversationWindowSize <= 0 {
conversationWindowSize = defaultConversationWindowSize
}
tb := &ConversationWindowBuffer{
ConversationWindowSize: conversationWindowSize,
// Reuse the plain buffer's defaults/option handling; the window
// size is the only extra field this type adds.
ConversationBuffer: *applyBufferOptions(options...),
}
return tb
}

type ConversationWindowBuffer struct {
ConversationBuffer // embedded: provides history storage and prefixes
ConversationWindowSize int // window size set by NewConversationWindowBuffer
}

该类型内嵌(组合)了 github.com/tmc/langchaingo@v0.1.13/memory/buffer.go 中的ConversationBuffer:
// ConversationBuffer is the basic conversation memory shared (by
// embedding) with the window- and token-limited variants.
type ConversationBuffer struct {
ChatHistory schema.ChatMessageHistory // pluggable message store; defaults to in-memory (see applyBufferOptions)
ReturnMessages bool // presumably: return raw messages instead of a formatted string — confirm in LoadMemoryVariables
InputKey string // key of the user input in the chain's input map ("" = infer)
OutputKey string // key of the model output in the chain's output map ("" = infer)
HumanPrefix string // label for human turns when formatting history (default "Human")
AIPrefix string // label for AI turns when formatting history (default "AI")
MemoryKey string // variable name the history is exposed under (default "history")
}

其中聊天历史的定义是一个接口,支持添加普通消息、用户消息还有AI响应消息:
github.com/tmc/langchaingo@v0.1.13/schema/chat_message_history.go
// ChatMessageHistory is the interface for chat history in memory/store.
type ChatMessageHistory interface {
// AddMessage adds a message to the store.
AddMessage(ctx context.Context, message llms.ChatMessage) error
// AddUserMessage is a convenience method for adding a human message string
// to the store.
AddUserMessage(ctx context.Context, message string) error
// AddAIMessage is a convenience method for adding an AI message string to
// the store.
AddAIMessage(ctx context.Context, message string) error
// Clear removes all messages from the store.
Clear(ctx context.Context) error
// Messages retrieves all messages from the store.
Messages(ctx context.Context) ([]llms.ChatMessage, error)
// SetMessages replaces existing messages in the store.
SetMessages(ctx context.Context, messages []llms.ChatMessage) error
}

另外两个的定义如下:
// applyBufferOptions builds a ConversationBuffer with the package
// defaults ("Human"/"AI" prefixes, "history" memory key), applies the
// caller's functional options on top, and falls back to an in-memory
// chat history when no store was supplied via options.
func applyBufferOptions(opts ...ConversationBufferOption) *ConversationBuffer {
m := &ConversationBuffer{
ReturnMessages: false,
InputKey: "",
OutputKey: "",
HumanPrefix: "Human",
AIPrefix: "AI",
MemoryKey: "history",
}
// Functional options mutate the buffer in place.
for _, opt := range opts {
opt(m)
}
// Only create the default store if no option set one.
if m.ChatHistory == nil {
m.ChatHistory = NewChatMessageHistory()
}
return m
}

func NewConversationTokenBuffer(
llm llms.Model, // presumably used to count tokens when trimming — confirm in the trimming logic
maxTokenLimit int, // upper bound on tokens retained in memory
options ...ConversationBufferOption,
) *ConversationTokenBuffer {
tb := &ConversationTokenBuffer{
LLM: llm,
MaxTokenLimit: maxTokenLimit,
// Same base buffer and option handling as the other variants.
ConversationBuffer: *applyBufferOptions(options...),
}
return tb
}

区别是后者添加了token数限制。
定义完记忆缓冲区后,接着就是初始化聊天会话
// NewConversation wires an LLM and a memory implementation into an
// LLMChain using the stock conversation prompt. The template declares
// exactly two variables: "history" (filled from memory) and "input"
// (the user's current message).
func NewConversation(llm llms.Model, memory schema.Memory) LLMChain {
return LLMChain{
Prompt: prompts.NewPromptTemplate(
_conversationTemplate,
[]string{"history", "input"},
),
LLM: llm,
Memory: memory,
// NewSimple presumably passes the completion through unchanged.
OutputParser: outputparser.NewSimple(),
OutputKey: _llmChainDefaultOutputKey, // "text"
}
}

const _llmChainDefaultOutputKey = "text"

可以看到,它定义了一个提示词模板,模板里面有history和input两个参数,具体模板内容如下:
// _conversationTemplate is the default conversation prompt; {{.history}}
// and {{.input}} are substituted at call time (text/template syntax).
const _conversationTemplate = `The following is a friendly conversation between a human and an AI. The AI is talkative and provides lots of specific details from its context. If the AI does not know the answer to a question, it truthfully says it does not know.
Current conversation:
{{.history}}
Human: {{.input}}
AI:`

而LLMChain的定义如下:
// LLMChain combines a prompt template, a model, and optional memory,
// callbacks, and output parsing into one runnable chain.
type LLMChain struct {
Prompt prompts.FormatPrompter // template used to build the final prompt
LLM llms.Model // model invoked with the formatted prompt
Memory schema.Memory // conversation memory (loaded before / saved after each call)
CallbacksHandler callbacks.Handler // optional start/error/end hooks; may be nil
OutputParser schema.OutputParser[any] // post-processes the raw completion
OutputKey string // key of the result in the output map ("text" by default)
}

最后通过LLMChain的Run方法获取返回结果:
// Run executes chain c with a single scalar input and returns its single
// string output. After subtracting keys supplied by the chain's memory,
// exactly one input key must remain (else ErrMultipleInputsInRun), and
// the chain must declare exactly one output key (else
// ErrMultipleOutputsInRun) whose value is a string (else
// ErrWrongOutputTypeInRun).
func Run(ctx context.Context, c Chain, input any, options ...ChainCallOption) (string, error) {
inputKeys := c.GetInputKeys()
memoryKeys := c.GetMemory().MemoryVariables(ctx)
neededKeys := make([]string, 0, len(inputKeys))
// Remove keys gotten from the memory.
for _, inputKey := range inputKeys {
isInMemory := false
for _, memoryKey := range memoryKeys {
if inputKey == memoryKey {
isInMemory = true
continue
}
}
if isInMemory {
continue
}
neededKeys = append(neededKeys, inputKey)
}
// Exactly one key may remain; that is where the scalar input goes.
if len(neededKeys) != 1 {
return "", ErrMultipleInputsInRun
}
outputKeys := c.GetOutputKeys()
if len(outputKeys) != 1 {
return "", ErrMultipleOutputsInRun
}
inputValues := map[string]any{neededKeys[0]: input}
outputValues, err := Call(ctx, c, inputValues, options...)
if err != nil {
return "", err
}
// The single output must be a string for Run's signature to hold.
outputValue, ok := outputValues[outputKeys[0]].(string)
if !ok {
return "", ErrWrongOutputTypeInRun
}
return outputValue, nil
}

在Run内部调用了Call方法,它先从memory里面取出历史记录并与输入合并,接着调用LLM获取返回结果,并将本轮输入输出记录到memory里面供下次使用:
// Call is the standard function used for executing chains.
// It merges the caller's inputs with variables loaded from the chain's
// memory, invokes the chain, then saves the exchange back into memory so
// the next call sees it. A callback handler, when present, is notified
// of chain start, error, and end.
func Call(ctx context.Context, c Chain, inputValues map[string]any, options ...ChainCallOption) (map[string]any, error) { // nolint: lll
fullValues := make(map[string]any, 0)
for key, value := range inputValues {
fullValues[key] = value
}
// Layer memory-provided variables (e.g. "history") on top of inputs.
newValues, err := c.GetMemory().LoadMemoryVariables(ctx, inputValues)
if err != nil {
return nil, err
}
for key, value := range newValues {
fullValues[key] = value
}
callbacksHandler := getChainCallbackHandler(c)
if callbacksHandler != nil {
callbacksHandler.HandleChainStart(ctx, inputValues)
}
outputValues, err := callChain(ctx, c, fullValues, options...)
if err != nil {
if callbacksHandler != nil {
callbacksHandler.HandleChainError(ctx, err)
}
// Note: partial outputValues are returned alongside the error.
return outputValues, err
}
if callbacksHandler != nil {
callbacksHandler.HandleChainEnd(ctx, outputValues)
}
// Persist this turn so the next call's history includes it.
if err = c.GetMemory().SaveContext(ctx, inputValues, outputValues); err != nil {
return outputValues, err
}
return outputValues, nil
}

其中的callChain会先校验输入,再调用chain自身的Call方法获取输出,最后校验输出:
// callChain validates fullValues against the chain's declared inputs,
// executes the chain, and validates its declared outputs. On error the
// (possibly partial) output map is returned with the error.
func callChain(
ctx context.Context,
c Chain,
fullValues map[string]any,
options ...ChainCallOption,
) (map[string]any, error) {
if err := validateInputs(c, fullValues); err != nil {
return nil, err
}
outputValues, err := c.Call(ctx, fullValues, options...)
if err != nil {
return outputValues, err
}
if err := validateOutputs(c, outputValues); err != nil {
return outputValues, err
}
return outputValues, nil
}

至此,历史记录的相关源码介绍完毕。
本文分享自 golang算法架构leetcode技术php 微信公众号,前往查看
如有侵权,请联系 cloudcommunity@tencent.com 删除。
本文参与 腾讯云自媒体同步曝光计划 ,欢迎热爱写作的你一起参与!