
提到go语言的rpc大家习惯性和grpc-go画等号,其实不然,golang语言包里提供了自己的rpc实现,即net/rpc,下面我们通过例子分析下它的源码实现。
首先是server代码
package main
import (
"errors"
"log"
"net"
"net/http"
"net/rpc"
)
// Args 定义RPC调用参数结构
type Args struct {
A, B int
}
// MathService 定义数学运算服务
type MathService struct{}
// Add 实现加法运算方法
func (m *MathService) Add(args *Args, reply *int) error {
if args == nil {
return errors.New("参数不能为空")
}
*reply = args.A + args.B
return nil
}
func main() {
// 创建服务实例
mathService := new(MathService)
// 注册RPC服务
err := rpc.Register(mathService)
if err != nil {
log.Fatal("注册服务失败:", err)
}
// 设置HTTP处理器
rpc.HandleHTTP()
// 启动HTTP服务
listener, err := net.Listen("tcp", ":1234")
if err != nil {
log.Fatal("监听失败:", err)
}
log.Println("RPC服务已启动,监听端口1234...")
err = http.Serve(listener, nil)
if err != nil {
log.Fatal("服务启动失败:", err)
}
}我们实现了一个service MathService,然后通过rpc.Register(mathService)进行了注册,然后设置了 rpc.HandleHTTP()处理器,最后就是listen。整体流程和grpc-go查不到,只是不用生成pb代码。
实现一个rpc竟然如此简单,下面我们分析下它的源码实现,首先看Server端,Register方法位于net/rpc/server.go
func Register(rcvr any) error { return DefaultServer.Register(rcvr) }func (server *Server) Register(rcvr any) error {
return server.register(rcvr, "", false)
}func (server *Server) register(rcvr any, name string, useName bool) error {
s := new(service)
s.typ = reflect.TypeOf(rcvr)
s.rcvr = reflect.ValueOf(rcvr)
sname := name
if !useName {
sname = reflect.Indirect(s.rcvr).Type().Name() s.method = suitableMethods(s.typ, logRegisterError)
if _, dup := server.serviceMap.LoadOrStore(sname, s); dup {
return errors.New("rpc: service already defined: " + sname)
}通过反射拿到server的类型,然后再通过反射拿到方法列表,最后进行方法的注册
func suitableMethods(typ reflect.Type, logErr bool) map[string]*methodType {
methods := make(map[string]*methodType)
for m := 0; m < typ.NumMethod(); m++ {
method := typ.Method(m)
mtype := method.Type
mname := method.Name
methods[mname] = &methodType{method: method, ArgType: argType, ReplyType: replyType}其中的Server定义如下,serviceMap其实就是一个sync.Map,key就是service名称,值就是方法名称到方法描述结构体的映射。至此service注册流程完毕,就是保存服务的元信息。
type Server struct {
serviceMap sync.Map // map[string]*service
reqLock sync.Mutex // protects freeReq
freeReq *Request
respLock sync.Mutex // protects freeResp
freeResp *Response
}HandleHttp方法定于如下
func HandleHTTP() {
DefaultServer.HandleHTTP(DefaultRPCPath, DefaultDebugPath)
}const (
// Defaults used by HandleHTTP
DefaultRPCPath = "/_goRPC_"
DefaultDebugPath = "/debug/rpc"
)如果请求的http路径是上面两个path就路由到刚刚注册的rpc服务处理
func (server *Server) HandleHTTP(rpcPath, debugPath string) {
http.Handle(rpcPath, server)
http.Handle(debugPath, debugHTTP{server})
}接着看下server的ServeHTTP方法,只处理CONNECT请求
func (server *Server) ServeHTTP(w http.ResponseWriter, req *http.Request) {
if req.Method != "CONNECT" {
w.Header().Set("Content-Type", "text/plain; charset=utf-8")
w.WriteHeader(http.StatusMethodNotAllowed)
io.WriteString(w, "405 must CONNECT\n")
return
}
conn, _, err := w.(http.Hijacker).Hijack()
if err != nil {
log.Print("rpc hijacking ", req.RemoteAddr, ": ", err.Error())
return
}
io.WriteString(conn, "HTTP/1.0 "+connected+"\n\n")
server.ServeConn(conn)
}var connected = "200 Connected to Go RPC"每建立一个连接都会启动一个gotoutine进行处理
func (server *Server) Accept(lis net.Listener) {
for {
conn, err := lis.Accept()
if err != nil {
log.Print("rpc.Serve: accept:", err.Error())
return
}
go server.ServeConn(conn)
}
}处理方法详情如下,它的编码器和解码器都使用的是gob格式,这和pb不一样。
func (server *Server) ServeConn(conn io.ReadWriteCloser) {
buf := bufio.NewWriter(conn)
srv := &gobServerCodec{
rwc: conn,
dec: gob.NewDecoder(conn),
enc: gob.NewEncoder(buf),
encBuf: buf,
}
server.ServeCodec(srv)
}接着是ServeCodec方法,在for循环里不断解码请求,然后启用一个协程 service.call进行处理:
func (server *Server) ServeCodec(codec ServerCodec) {
sending := new(sync.Mutex)
wg := new(sync.WaitGroup)
for {
service, mtype, req, argv, replyv, keepReading, err := server.readRequest(codec)
if err != nil {
if debugLog && err != io.EOF {
log.Println("rpc:", err)
}
if !keepReading {
break
}
// send a response if we actually managed to read a header.
if req != nil {
server.sendResponse(sending, req, invalidRequest, codec, err.Error())
server.freeRequest(req)
}
continue
}
wg.Add(1)
go service.call(server, sending, wg, mtype, req, argv, replyv, codec)
}
// We've seen that there are no more requests.
// Wait for responses to be sent before closing codec.
wg.Wait()
codec.Close()
}接着看下call方法实现,本质就是调用了反射function.Call方法。
func (s *service) call(server *Server, sending *sync.Mutex, wg *sync.WaitGroup, mtype *methodType, req *Request, argv, replyv reflect.Value, codec ServerCodec) {
if wg != nil {
defer wg.Done()
}
mtype.Lock()
mtype.numCalls++
mtype.Unlock()
function := mtype.method.Func
// Invoke the method, providing a new value for the reply.
returnValues := function.Call([]reflect.Value{s.rcvr, argv, replyv})
// The return value for the method is an error.
errInter := returnValues[0].Interface()
errmsg := ""
if errInter != nil {
errmsg = errInter.(error).Error()
}
server.sendResponse(sending, req, replyv.Interface(), codec, errmsg)
server.freeRequest(req)
}那我们就好奇了,服务是怎么知道调用了那个Service的哪个Method呢?答案就在请求的解析方法里
func (server *Server) readRequest(codec ServerCodec) (service *service, mtype *methodType, req *Request, argv, replyv reflect.Value, keepReading bool, err error) {
service, mtype, req, keepReading, err = server.readRequestHeader(codec)
if err != nil {
if !keepReading {
return
}
// discard body
codec.ReadRequestBody(nil)
return
}
// Decode the argument value.
argIsValue := false // if true, need to indirect before calling.
if mtype.ArgType.Kind() == reflect.Pointer {
argv = reflect.New(mtype.ArgType.Elem())
} else {
argv = reflect.New(mtype.ArgType)
argIsValue = true
}
// argv guaranteed to be a pointer now.
if err = codec.ReadRequestBody(argv.Interface()); err != nil {
return
}
if argIsValue {
argv = argv.Elem()
}
replyv = reflect.New(mtype.ReplyType.Elem())
switch mtype.ReplyType.Elem().Kind() {
case reflect.Map:
replyv.Elem().Set(reflect.MakeMap(mtype.ReplyType.Elem()))
case reflect.Slice:
replyv.Elem().Set(reflect.MakeSlice(mtype.ReplyType.Elem(), 0, 0))
}
return
}它先解析请求的header,在ServiceMethod字段中解析出service名称和方法名称,然后在我们前面注册的Service信息里通过名称匹配到对应的Service和方法即可。ServiceMethod 是我们请求传递的参数,具体格式是:Service.Method,只需要通过点号切分即可得到。
func (server *Server) readRequestHeader(codec ServerCodec) (svc *service, mtype *methodType, req *Request, keepReading bool, err error) {
// Grab the request header.
req = server.getRequest()
err = codec.ReadRequestHeader(req)
if err != nil {
req = nil
if err == io.EOF || err == io.ErrUnexpectedEOF {
return
}
err = errors.New("rpc: server cannot decode request: " + err.Error())
return
}
// We read the header successfully. If we see an error now,
// we can still recover and move on to the next request.
keepReading = true
dot := strings.LastIndex(req.ServiceMethod, ".")
if dot < 0 {
err = errors.New("rpc: service/method request ill-formed: " + req.ServiceMethod)
return
}
serviceName := req.ServiceMethod[:dot]
methodName := req.ServiceMethod[dot+1:]
// Look up the request.
svci, ok := server.serviceMap.Load(serviceName)
if !ok {
err = errors.New("rpc: can't find service " + req.ServiceMethod)
return
}
svc = svci.(*service)
mtype = svc.method[methodName]
if mtype == nil {
err = errors.New("rpc: can't find method " + req.ServiceMethod)
}
return
}其中请求定义如下
type Request struct {
ServiceMethod string // format: "Service.Method"
Seq uint64 // sequence number chosen by client
next *Request // for free list in Server
}本文分享自 golang算法架构leetcode技术php 微信公众号,前往查看
如有侵权,请联系 cloudcommunity@tencent.com 删除。
本文参与 腾讯云自媒体同步曝光计划 ,欢迎热爱写作的你一起参与!