首页
学习
活动
专区
圈层
工具
发布
社区首页 >专栏 >golang源码分析:langchaingo(5)

golang源码分析:langchaingo(5)

作者头像
golangLeetcode
发布2026-03-18 17:56:25
发布2026-03-18 17:56:25
850
举报

前面介绍的都是无状态的单词请求,如果希望连续聊天,并且AI能根据历史的聊天记录给出相关联的回答,怎么做呢?看下面的例子:

代码语言:javascript
复制
    memoryBuffer := memory.NewConversationWindowBuffer(10)
    chatChain := chains.NewConversation(llm, memoryBuffer)
    messages := []string{
        "你好,我叫PBR",
        "你知道我叫什么吗?",
        "你可以解决什么问题?",
    }
    for _, message := range messages {
        completion, err := chains.Run(ctx, chatChain, message)
        for {
            if err == nil {
                break
            }
            time.Sleep(30 * time.Second)
            completion, err = chains.Run(ctx, chatChain, message)
        }
        chatMessages, _ := memoryBuffer.ChatHistory.Messages(ctx)
        fmt.Printf("上下文对话历史:%v\n", chatMessages)
        fmt.Printf("输入:%v\n输出:%v\n", message, completion)

可以看到,我们定义了NewConversationWindowBuffer:带聊天记录条数窗口的buffer,当然也可以根据实际情况定义更多类型的buffer:简单buffer或者指定token长度的buffer

代码语言:javascript
复制
memoryBuffer := memory.NewConversationBuffer()
memoryBuffer := memory.NewConversationWindowBuffer(10)
memoryBuffer := memory.NewConversationTokenBuffer(llm, 1024)  

github.com/tmc/langchaingo@v0.1.13/memory/window_buffer.go

代码语言:javascript
复制
func NewConversationWindowBuffer(
    conversationWindowSize int,
    options ...ConversationBufferOption,
) *ConversationWindowBuffer {
    if conversationWindowSize <= 0 {
        conversationWindowSize = defaultConversationWindowSize
    }
    tb := &ConversationWindowBuffer{
        ConversationWindowSize: conversationWindowSize,
        ConversationBuffer:     *applyBufferOptions(options...),
    }
    return tb
}
代码语言:javascript
复制
type ConversationWindowBuffer struct {
    ConversationBuffer
    ConversationWindowSize int
}

继承自github.com/tmc/langchaingo@v0.1.13/memory/buffer.go

代码语言:javascript
复制
type ConversationBuffer struct {
    ChatHistory schema.ChatMessageHistory
    ReturnMessages bool
    InputKey       string
    OutputKey      string
    HumanPrefix    string
    AIPrefix       string
    MemoryKey      string
}

其中聊天历史的定义是一个接口,支持添加普通消息、用户消息还有AI响应消息:

github.com/tmc/langchaingo@v0.1.13/schema/chat_message_history.go

代码语言:javascript
复制
// ChatMessageHistory is the interface for chat history in memory/store.
type ChatMessageHistory interface {
    // AddMessage adds a message to the store.
    AddMessage(ctx context.Context, message llms.ChatMessage) error
    // AddUserMessage is a convenience method for adding a human message string
    // to the store.
    AddUserMessage(ctx context.Context, message string) error
    // AddAIMessage is a convenience method for adding an AI message string to
    // the store.
    AddAIMessage(ctx context.Context, message string) error
    // Clear removes all messages from the store.
    Clear(ctx context.Context) error
    // Messages retrieves all messages from the store
    Messages(ctx context.Context) ([]llms.ChatMessage, error)
    // SetMessages replaces existing messages in the store
    SetMessages(ctx context.Context, messages []llms.ChatMessage) error
}

另外两个的定义如下:

代码语言:javascript
复制
func applyBufferOptions(opts ...ConversationBufferOption) *ConversationBuffer {
    m := &ConversationBuffer{
        ReturnMessages: false,
        InputKey:       "",
        OutputKey:      "",
        HumanPrefix:    "Human",
        AIPrefix:       "AI",
        MemoryKey:      "history",
    }
    for _, opt := range opts {
        opt(m)
    }
    if m.ChatHistory == nil {
        m.ChatHistory = NewChatMessageHistory()
    }
    return m
}
代码语言:javascript
复制
func NewConversationTokenBuffer(
    llm llms.Model,
    maxTokenLimit int,
    options ...ConversationBufferOption,
) *ConversationTokenBuffer {
    tb := &ConversationTokenBuffer{
        LLM:                llm,
        MaxTokenLimit:      maxTokenLimit,
        ConversationBuffer: *applyBufferOptions(options...),
    }
    return tb
}

区别是后者添加了token数限制。

定义完记忆缓冲区后,接着就是初始化聊天会话

代码语言:javascript
复制
func NewConversation(llm llms.Model, memory schema.Memory) LLMChain {
    return LLMChain{
        Prompt: prompts.NewPromptTemplate(
            _conversationTemplate,
            []string{"history", "input"},
        ),
        LLM:          llm,
        Memory:       memory,
        OutputParser: outputparser.NewSimple(),
        OutputKey:    _llmChainDefaultOutputKey,
    }
}
代码语言:javascript
复制
const _llmChainDefaultOutputKey = "text"

可以看到,它定义了一个提示词模板,模板里面有history和input两个参数,具体模板内容如下:

代码语言:javascript
复制
const _conversationTemplate = `The following is a friendly conversation between a human and an AI. The AI is talkative and provides lots of specific details from its context. If the AI does not know the answer to a question, it truthfully says it does not know.
Current conversation:
{{.history}}
Human: {{.input}}
AI:`

而LLMChain的定义如下:

代码语言:javascript
复制
type LLMChain struct {
    Prompt           prompts.FormatPrompter
    LLM              llms.Model
    Memory           schema.Memory
    CallbacksHandler callbacks.Handler
    OutputParser     schema.OutputParser[any]
    OutputKey string
}

最后通过llmchain的Run方法获取返回结果:

代码语言:javascript
复制
func Run(ctx context.Context, c Chain, input any, options ...ChainCallOption) (string, error) {
    inputKeys := c.GetInputKeys()
    memoryKeys := c.GetMemory().MemoryVariables(ctx)
    neededKeys := make([]string, 0, len(inputKeys))
    // Remove keys gotten from the memory.
    for _, inputKey := range inputKeys {
        isInMemory := false
        for _, memoryKey := range memoryKeys {
            if inputKey == memoryKey {
                isInMemory = true
                continue
            }
        }
        if isInMemory {
            continue
        }
        neededKeys = append(neededKeys, inputKey)
    }
    if len(neededKeys) != 1 {
        return "", ErrMultipleInputsInRun
    }
    outputKeys := c.GetOutputKeys()
    if len(outputKeys) != 1 {
        return "", ErrMultipleOutputsInRun
    }
    inputValues := map[string]any{neededKeys[0]: input}
    outputValues, err := Call(ctx, c, inputValues, options...)
    if err != nil {
        return "", err
    }
    outputValue, ok := outputValues[outputKeys[0]].(string)
    if !ok {
        return "", ErrWrongOutputTypeInRun
    }
    return outputValue, nil
}

在Run内部调用了Call方法,它先从memory里面取历史记录,然后根据选项筛选历史记录,接着调用LLM获取返回结果,并将结果记录到memory里面给下次使用

代码语言:javascript
复制
// Call is the standard function used for executing chains.
func Call(ctx context.Context, c Chain, inputValues map[string]any, options ...ChainCallOption) (map[string]any, error) { // nolint: lll
    fullValues := make(map[string]any, 0)
    for key, value := range inputValues {
        fullValues[key] = value
    }
    newValues, err := c.GetMemory().LoadMemoryVariables(ctx, inputValues)
    if err != nil {
        return nil, err
    }
    for key, value := range newValues {
        fullValues[key] = value
    }
    callbacksHandler := getChainCallbackHandler(c)
    if callbacksHandler != nil {
        callbacksHandler.HandleChainStart(ctx, inputValues)
    }
    outputValues, err := callChain(ctx, c, fullValues, options...)
    if err != nil {
        if callbacksHandler != nil {
            callbacksHandler.HandleChainError(ctx, err)
        }
        return outputValues, err
    }
    if callbacksHandler != nil {
        callbacksHandler.HandleChainEnd(ctx, outputValues)
    }
    if err = c.GetMemory().SaveContext(ctx, inputValues, outputValues); err != nil {
        return outputValues, err
    }
    return outputValues, nil
}

其中的callChain就是调用chain中的每个元素的call方法

代码语言:javascript
复制
func callChain(
    ctx context.Context,
    c Chain,
    fullValues map[string]any,
    options ...ChainCallOption,
) (map[string]any, error) {
    if err := validateInputs(c, fullValues); err != nil {
        return nil, err
    }
    outputValues, err := c.Call(ctx, fullValues, options...)
    if err != nil {
        return outputValues, err
    }
    if err := validateOutputs(c, outputValues); err != nil {
        return outputValues, err
    }
    return outputValues, nil
}

至此历史记录的相关源码介绍完毕。

本文参与 腾讯云自媒体同步曝光计划,分享自微信公众号。
原始发表:2025-06-06,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 golang算法架构leetcode技术php 微信公众号,前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体同步曝光计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档