首页
学习
活动
专区
圈层
工具
发布
社区首页 >专栏 >golang源码分析:net/rpc(1)

golang源码分析:net/rpc(1)

作者头像
golangLeetcode
发布2026-03-18 18:17:05
发布2026-03-18 18:17:05
540
举报

提到go语言的rpc大家习惯性和grpc-go画等号,其实不然,golang语言包里提供了自己的rpc实现,即net/rpc,下面我们通过例子分析下它的源码实现。

首先是server代码

代码语言:javascript
复制
package main
import (
    "errors"
    "log"
    "net"
    "net/http"
    "net/rpc"
)
// Args 定义RPC调用参数结构
type Args struct {
    A, B int
}
// MathService 定义数学运算服务
type MathService struct{}
// Add 实现加法运算方法
func (m *MathService) Add(args *Args, reply *int) error {
    if args == nil {
        return errors.New("参数不能为空")
    }
    *reply = args.A + args.B
    return nil
}
func main() {
    // 创建服务实例
    mathService := new(MathService)
    // 注册RPC服务
    err := rpc.Register(mathService)
    if err != nil {
        log.Fatal("注册服务失败:", err)
    }
    // 设置HTTP处理器
    rpc.HandleHTTP()
    // 启动HTTP服务
    listener, err := net.Listen("tcp", ":1234")
    if err != nil {
        log.Fatal("监听失败:", err)
    }
    log.Println("RPC服务已启动,监听端口1234...")
    err = http.Serve(listener, nil)
    if err != nil {
        log.Fatal("服务启动失败:", err)
    }
}

我们实现了一个service MathService,然后通过rpc.Register(mathService)进行了注册,然后设置了 rpc.HandleHTTP()处理器,最后就是listen。整体流程和grpc-go查不到,只是不用生成pb代码。

实现一个rpc竟然如此简单,下面我们分析下它的源码实现,首先看Server端,Register方法位于net/rpc/server.go

代码语言:javascript
复制
func Register(rcvr any) error { return DefaultServer.Register(rcvr) }
代码语言:javascript
复制
func (server *Server) Register(rcvr any) error {
    return server.register(rcvr, "", false)
}
代码语言:javascript
复制
func (server *Server) register(rcvr any, name string, useName bool) error {
    s := new(service)
    s.typ = reflect.TypeOf(rcvr)
    s.rcvr = reflect.ValueOf(rcvr)
    sname := name
    if !useName {
        sname = reflect.Indirect(s.rcvr).Type().Name()    s.method = suitableMethods(s.typ, logRegisterError)
    if _, dup := server.serviceMap.LoadOrStore(sname, s); dup {
        return errors.New("rpc: service already defined: " + sname)
    }

通过反射拿到server的类型,然后再通过反射拿到方法列表,最后进行方法的注册

代码语言:javascript
复制
func suitableMethods(typ reflect.Type, logErr bool) map[string]*methodType {
    methods := make(map[string]*methodType)
    for m := 0; m < typ.NumMethod(); m++ {
        method := typ.Method(m)
        mtype := method.Type
        mname := method.Name
        methods[mname] = &methodType{method: method, ArgType: argType, ReplyType: replyType}

其中的Server定义如下,serviceMap其实就是一个sync.Map,key就是service名称,值就是方法名称到方法描述结构体的映射。至此service注册流程完毕,就是保存服务的元信息。

代码语言:javascript
复制
type Server struct {
    serviceMap sync.Map   // map[string]*service
    reqLock    sync.Mutex // protects freeReq
    freeReq    *Request
    respLock   sync.Mutex // protects freeResp
    freeResp   *Response
}

HandleHttp方法定于如下

代码语言:javascript
复制
func HandleHTTP() {
    DefaultServer.HandleHTTP(DefaultRPCPath, DefaultDebugPath)
}
代码语言:javascript
复制
const (
    // Defaults used by HandleHTTP
    DefaultRPCPath   = "/_goRPC_"
    DefaultDebugPath = "/debug/rpc"
)

如果请求的http路径是上面两个path就路由到刚刚注册的rpc服务处理

代码语言:javascript
复制
func (server *Server) HandleHTTP(rpcPath, debugPath string) {
    http.Handle(rpcPath, server)
    http.Handle(debugPath, debugHTTP{server})
}

接着看下server的ServeHTTP方法,只处理CONNECT请求

代码语言:javascript
复制
func (server *Server) ServeHTTP(w http.ResponseWriter, req *http.Request) {
    if req.Method != "CONNECT" {
        w.Header().Set("Content-Type", "text/plain; charset=utf-8")
        w.WriteHeader(http.StatusMethodNotAllowed)
        io.WriteString(w, "405 must CONNECT\n")
        return
    }    
    conn, _, err := w.(http.Hijacker).Hijack()
    if err != nil {
        log.Print("rpc hijacking ", req.RemoteAddr, ": ", err.Error())
        return
    }
    io.WriteString(conn, "HTTP/1.0 "+connected+"\n\n")
    server.ServeConn(conn)
}
代码语言:javascript
复制
var connected = "200 Connected to Go RPC"

每建立一个连接都会启动一个gotoutine进行处理

代码语言:javascript
复制
func (server *Server) Accept(lis net.Listener) {
    for {
        conn, err := lis.Accept()
        if err != nil {
            log.Print("rpc.Serve: accept:", err.Error())
            return
        }
        go server.ServeConn(conn)
    }
}

处理方法详情如下,它的编码器和解码器都使用的是gob格式,这和pb不一样。

代码语言:javascript
复制
func (server *Server) ServeConn(conn io.ReadWriteCloser) {
    buf := bufio.NewWriter(conn)
    srv := &gobServerCodec{
        rwc:    conn,
        dec:    gob.NewDecoder(conn),
        enc:    gob.NewEncoder(buf),
        encBuf: buf,
    }
    server.ServeCodec(srv)
}

接着是ServeCodec方法,在for循环里不断解码请求,然后启用一个协程 service.call进行处理:

代码语言:javascript
复制
func (server *Server) ServeCodec(codec ServerCodec) {
    sending := new(sync.Mutex)
    wg := new(sync.WaitGroup)
    for {
        service, mtype, req, argv, replyv, keepReading, err := server.readRequest(codec)
        if err != nil {
            if debugLog && err != io.EOF {
                log.Println("rpc:", err)
            }
            if !keepReading {
                break
            }
            // send a response if we actually managed to read a header.
            if req != nil {
                server.sendResponse(sending, req, invalidRequest, codec, err.Error())
                server.freeRequest(req)
            }
            continue
        }
        wg.Add(1)
        go service.call(server, sending, wg, mtype, req, argv, replyv, codec)
    }
    // We've seen that there are no more requests.
    // Wait for responses to be sent before closing codec.
    wg.Wait()
    codec.Close()
}

接着看下call方法实现,本质就是调用了反射function.Call方法。

代码语言:javascript
复制
func (s *service) call(server *Server, sending *sync.Mutex, wg *sync.WaitGroup, mtype *methodType, req *Request, argv, replyv reflect.Value, codec ServerCodec) {
    if wg != nil {
        defer wg.Done()
    }
    mtype.Lock()
    mtype.numCalls++
    mtype.Unlock()
    function := mtype.method.Func
    // Invoke the method, providing a new value for the reply.
    returnValues := function.Call([]reflect.Value{s.rcvr, argv, replyv})
    // The return value for the method is an error.
    errInter := returnValues[0].Interface()
    errmsg := ""
    if errInter != nil {
        errmsg = errInter.(error).Error()
    }
    server.sendResponse(sending, req, replyv.Interface(), codec, errmsg)
    server.freeRequest(req)
}

那我们就好奇了,服务是怎么知道调用了那个Service的哪个Method呢?答案就在请求的解析方法里

代码语言:javascript
复制
func (server *Server) readRequest(codec ServerCodec) (service *service, mtype *methodType, req *Request, argv, replyv reflect.Value, keepReading bool, err error) {
    service, mtype, req, keepReading, err = server.readRequestHeader(codec)
    if err != nil {
        if !keepReading {
            return
        }
        // discard body
        codec.ReadRequestBody(nil)
        return
    }
    // Decode the argument value.
    argIsValue := false // if true, need to indirect before calling.
    if mtype.ArgType.Kind() == reflect.Pointer {
        argv = reflect.New(mtype.ArgType.Elem())
    } else {
        argv = reflect.New(mtype.ArgType)
        argIsValue = true
    }
    // argv guaranteed to be a pointer now.
    if err = codec.ReadRequestBody(argv.Interface()); err != nil {
        return
    }
    if argIsValue {
        argv = argv.Elem()
    }
    replyv = reflect.New(mtype.ReplyType.Elem())
    switch mtype.ReplyType.Elem().Kind() {
    case reflect.Map:
        replyv.Elem().Set(reflect.MakeMap(mtype.ReplyType.Elem()))
    case reflect.Slice:
        replyv.Elem().Set(reflect.MakeSlice(mtype.ReplyType.Elem(), 0, 0))
    }
    return
}

它先解析请求的header,在ServiceMethod字段中解析出service名称和方法名称,然后在我们前面注册的Service信息里通过名称匹配到对应的Service和方法即可。ServiceMethod 是我们请求传递的参数,具体格式是:Service.Method,只需要通过点号切分即可得到。

代码语言:javascript
复制
func (server *Server) readRequestHeader(codec ServerCodec) (svc *service, mtype *methodType, req *Request, keepReading bool, err error) {
    // Grab the request header.
    req = server.getRequest()
    err = codec.ReadRequestHeader(req)
    if err != nil {
        req = nil
        if err == io.EOF || err == io.ErrUnexpectedEOF {
            return
        }
        err = errors.New("rpc: server cannot decode request: " + err.Error())
        return
    }
    // We read the header successfully. If we see an error now,
    // we can still recover and move on to the next request.
    keepReading = true
    dot := strings.LastIndex(req.ServiceMethod, ".")
    if dot < 0 {
        err = errors.New("rpc: service/method request ill-formed: " + req.ServiceMethod)
        return
    }
    serviceName := req.ServiceMethod[:dot]
    methodName := req.ServiceMethod[dot+1:]
    // Look up the request.
    svci, ok := server.serviceMap.Load(serviceName)
    if !ok {
        err = errors.New("rpc: can't find service " + req.ServiceMethod)
        return
    }
    svc = svci.(*service)
    mtype = svc.method[methodName]
    if mtype == nil {
        err = errors.New("rpc: can't find method " + req.ServiceMethod)
    }
    return
}

其中请求定义如下

代码语言:javascript
复制
type Request struct {
    ServiceMethod string   // format: "Service.Method"
    Seq           uint64   // sequence number chosen by client
    next          *Request // for free list in Server
}
本文参与 腾讯云自媒体同步曝光计划,分享自微信公众号。
原始发表:2025-06-21,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 golang算法架构leetcode技术php 微信公众号,前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体同步曝光计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档