Gin源码分析 – 中间件 – 源码分析

1 介绍

上一篇文章对中间件的使用方法以及如何开发中间件进行了简要的概述,本文主要对Gin的中间件的相关代码进行分析,从而对中间件有更深入的理解。

2 数据结构

2.1 RouterGroup

// RouterGroup is used internally to configure router, 
// a RouterGroup is associated with
// a prefix and an array of handlers (middleware).
type RouterGroup struct {
	Handlers HandlersChain
	basePath string
	engine   *Engine
	root     bool
}

RouterGroup管理一个路由分组

(1)HandlersChain,用于管理该分组中注册的中间件,HandlersChain实际上是Handlers的一个切片,定义如下:

// HandlersChain defines a HandlerFunc array.
type HandlersChain []HandlerFunc

HandlerFunc实际上是中间件返回的一个处理方法,定义如下:

// HandlerFunc defines the handler used by gin middleware as return value.
type HandlerFunc func(*Context)

(2)root,如果是真则是全局路由组,该路由实际上是Engine的组成部分,定义如下:

type Engine struct {
	RouterGroup
        ...
}

(3)basePath,路由组的前缀URL;

(4)engine,Enginer指针。

3 注册中间件

3.1 全局注册

从我们最常用的函数gin.Default()开始,它内部构造一个新的engine之后就通过Use()函数注册了Logger中间件和Recovery中间件,代码如下:

// Default returns an Engine instance with the 
// Logger and Recovery middleware already attached.
func Default() *Engine {
	debugPrintWARNINGDefault()
	engine := New()
	engine.Use(Logger(), Recovery())
	return engine
}

在上述代码中,Engine的Use方法实现了中间件的全局注册,方法如下:

// Use attaches a global middleware to the router. 
// ie. the middleware attached though Use() will be
// included in the handlers chain for every single request. 
// Even 404, 405, static files...
// For example, this is the right place for a 
// logger or error management middleware.
func (engine *Engine) Use(middleware ...HandlerFunc) IRoutes {
	engine.RouterGroup.Use(middleware...)
	engine.rebuild404Handlers()
	engine.rebuild405Handlers()
	return engine
}

实际上调用了RouterGroup的Use方法,该方法定义如下:

// Use adds middleware to the group, see example code in GitHub.
func (group *RouterGroup) Use(middleware ...HandlerFunc) IRoutes {
	group.Handlers = append(group.Handlers, middleware...)
	return group.returnObj()
}

首先将注册的多个中间件,添加到group的Handlers中,然后调用returnObj返回 IRoutes接口,该方法如下:

func (group *RouterGroup) returnObj() IRoutes {
	if group.root {
		return group.engine
	}
	return group
}

如果是全局路由组,则返回Engine,否则返回自身。IRoutes接口定义如下:

// IRoutes defines all router handle interface.
type IRoutes interface {
	Use(...HandlerFunc) IRoutes

	Handle(string, string, ...HandlerFunc) IRoutes
	Any(string, ...HandlerFunc) IRoutes
	GET(string, ...HandlerFunc) IRoutes
	POST(string, ...HandlerFunc) IRoutes
	DELETE(string, ...HandlerFunc) IRoutes
	PATCH(string, ...HandlerFunc) IRoutes
	PUT(string, ...HandlerFunc) IRoutes
	OPTIONS(string, ...HandlerFunc) IRoutes
	HEAD(string, ...HandlerFunc) IRoutes

	StaticFile(string, string) IRoutes
	Static(string, string) IRoutes
	StaticFS(string, http.FileSystem) IRoutes
}

RouterGroup实现了所有这些方法,后续将展开分析,Engine中嵌入了RouterGroup类型,因此也实现了IRoutes接口。

3.2 局部注册

通过r.Group方法可以创建了一个路由分组,在创建的过程中可以直接注册本路由分组需要使用的中间件,代码如下:

r := gin.New()
user := r.Group("user", gin.Logger(), gin.Recovery())
{
	user.GET("info", func(context *gin.Context) {

	})
	user.GET("article", func(context *gin.Context) {

	})
}

// Group creates a new router group. 
// You should add all the routes that have common middlewares 
// or the same path prefix.
// For example, all the routes that use a common middleware for 
// authorization could be grouped.
func (group *RouterGroup) Group(
  relativePath string, 
  handlers ...HandlerFunc,
) *RouterGroup {
	return &RouterGroup{
		Handlers: group.combineHandlers(handlers),
		basePath: group.calculateAbsolutePath(relativePath),
		engine:   group.engine,
	}
}

RouterGroup的Group的方法实际上是创建了一个子路由组,并且可以直接进行中间件的组成工作:

(1)首先调用 combineHandlers将待注册中间件和父亲中间件进行合并;

(2)然后计算绝对路由;

(3)设置engine。

func (group *RouterGroup) combineHandlers(
  handlers HandlersChain,
) HandlersChain {
	finalSize := len(group.Handlers) + len(handlers)
	if finalSize >= int(abortIndex) {
		panic("too many handlers")
	}
	mergedHandlers := make(HandlersChain, finalSize)
	copy(mergedHandlers, group.Handlers)
	copy(mergedHandlers[len(group.Handlers):], handlers)
	return mergedHandlers
}

const abortIndex int8 = math.MaxInt8 / 2    // 127 / 2 = 63

(1)首先计算父亲中的中间件的数量,然后和待注册中间件的数量进行累加,一个路由组中注册的中间件数量不能超过63个;

(2)创建相应的数组,首先添加父路由组中的中间件,然后添加待注册的中间件。

3.3 单路由注册

在注册路由处理方法的时候也可以注册中间件,见下面的例子。

r.GET("/hello", gin.Logger(), gin.Recovery(), func(c *gin.Context) {
})

其实,我们自定义的处理函数,也是中间件,不过是最后一个中间件罢了,具体实现如下:

// GET is a shortcut for router.Handle("GET", path, handle).
func (group *RouterGroup) GET(
  relativePath string, 
  handlers ...HandlerFunc,
) IRoutes {
	return group.handle(http.MethodGet, relativePath, handlers)
}

func (group *RouterGroup) handle(
  httpMethod, 
  relativePath string, 
  handlers HandlersChain,
) IRoutes {
	absolutePath := group.calculateAbsolutePath(relativePath)
	handlers = group.combineHandlers(handlers)
	group.engine.addRoute(httpMethod, absolutePath, handlers)
	return group.returnObj()
}

首先GET方法内部调用了handle方法,在handle方法内容完成如下的工作:

(1)计算绝对请求路由;

(2)将路由组中已注册的中间件和待注册的中间以及处理函数进行合并;

(3)将合并后的处理方法数组调价的engine的路由树中,此函数在此不再展开,本系列中由专门的章节进行论述。

4 中间件的执行

4.1 上下文

每当有一个请求到达服务端时,Gin就会为这个请求分配一个Context, 该上下文中保存了这个请求对应的处理器链(中间件数组), 以及一个索引index用于记录当前处理到了哪个HandlerFunc,index的初始值为-1。

type Context struct {
        ...
	handlers HandlersChain
	index    int8
        ...
}

4.2 执行流程

在Gin中当接收到一个HTTP请求后,通过ServeHTTP方法完成对一个HTTP请求的处理。

// ServeHTTP conforms to the http.Handler interface.
func (engine *Engine) ServeHTTP(
  w http.ResponseWriter, 
  req *http.Request,
) {
	c := engine.pool.Get().(*Context)
	c.writermem.reset(w)
	c.Request = req
	c.reset()

	engine.handleHTTPRequest(c)

	engine.pool.Put(c)
}

在上面的代码中,首先创建一个上下文,然后通过handleHTTPRequest处理HTTP请求。

func (engine *Engine) handleHTTPRequest(c *Context) {
	// (1)获取请求方法和请求URL
        httpMethod := c.Request.Method
	rPath := c.Request.URL.Path
        ....
	
        // Find root of the tree for the given HTTP method
	t := engine.trees
	for i, tl := 0, len(t); i < tl; i++ 

                // (2)获取请求处理器链
		if value.handlers != nil {
			c.handlers = value.handlers
			c.fullPath = value.fullPath
                        // (3)通过c.Next()开始执行
			c.Next()
			c.writermem.WriteHeaderNow()
			return
		}
	}

        ...
}

这个方法比较长,在此处主要对主要逻辑进行分析,

(1)获取处理请求的类型,GET,POST等;获取请求URL;

(2)从Engie.trees中查找该请求的处理器,如果找到则将处理器链(中间件数组)复制到上下文的handlers中;

(3)调用c.Next()方法开始中间件的执行,可以看出整个请求的流程主要由c.Next()这个方法控制。

4.3 c.Next方法

// Next should be used only inside middleware.
// It executes the pending handlers in the chain 
// inside the calling handler.
// See example in GitHub.
func (c *Context) Next() {
	c.index++
	for c.index < int8(len(c.handlers)) {
		c.handlers[c.index](c)
		c.index++
	}
}

这个方法很简单,就是循环调用每一个中间件,但是由于每个中间件中还是会调用Next方法,因此理解起来就略有复杂了。下面以一个例子说明这个执行过程。

func indexHandler(c *gin.Context)  {
	fmt.Println("index")
	c.JSON(http.StatusOK, gin.H{
		"msg": "index",
	})
}

func m1(c *gin.Context) {
	fmt.Println("m1 in ...")
	c.Next() //调用后续的处理函数
	fmt.Printf("m1 out ...")
}

func m2(c *gin.Context)  {
	fmt.Println("m2 in ...")
	c.Next()  //调用后续的处理函数
	fmt.Println("m2 out ...")
}

func main()  {
	r := gin.New()
	r.GET("/index", m1, m2, indexHandler)
	r.Run(":8080")
}

运行上面这个例子,然后发起一个index请求,结果如下

m1 in ...
m2 in ...
index
m2 out ...
m1 out ...

析过程如下:

(1)handleHTTPRequest中调用c.Next(),此时index=0,因此获取第一个处理方法m1,并执行;

(2)m1中首先输出”m1 in…”;

(3)m1中调用c.Next(),此时index=1,获取第二个处理方法m2,并执行;

(4)m2中首先输出”m2 in…”;

(5)m2中调用c.Next(),此时index=2,获取第三个处理方法indexHandler,并执行;

(6)indexHandler完成函数执行,生成响应应答保存到上下文中,然后退出,返回m2中的Next继续执行;

(7)m2中index=3,已经大于等于整个处理器的数量了,从Next()中退出,然后输出”m2 out…”,然后退出,返回m1中的Next继续执行;

(8)m1中index=4,已经大于等于整个处理器的数量了,从Next()中退出,然后输出”m1 out…”,然后退出,返回handleHTTPReqest中的Next继续执行;

(9)handleHTTPReqest中index=5,已经大于等于整个处理器的数量了,从Next()中退出,然后向客户端发送HTTP响应,完成整个HTTP的执行流程。

4.4 c.Abort方法

当某一个中间件在执行的过程中如果出现异常,不在需要执行下一个处理方法时,可以调用Abort方法。

// Abort prevents pending handlers from being called. 
// Note that this will not stop the current handler.
// Let's say you have an authorization middleware that validates 
// that the current request is authorized.
// If the authorization fails (ex: the password does not match), 
// call Abort to ensure the remaining handlers
// for this request are not called.
func (c *Context) Abort() {
	c.index = abortIndex
}

// IsAborted returns true if the current context was aborted.
func (c *Context) IsAborted() bool {
	return c.index >= abortIndex
}

通过上面的分析可以看出整个执行控制是依靠index的数值进行控制的,当index大于当前的处理器的数量时则直接退出,在Gin定义了一个abortIndex,这个值是63,同时这个值也表示了一个处理器中最多能注册的中间件。

// abortIndex represents a typical value used in abort functions.
const abortIndex int8 = math.MaxInt8 >> 1 // 63

func (group *RouterGroup) combineHandlers(
  handlers HandlersChain,
) HandlersChain {
	finalSize := len(group.Handlers) + len(handlers)
  //最多的处理器的数量
	if finalSize >= int(abortIndex) {
		panic("too many handlers")
	}
	mergedHandlers := make(HandlersChain, finalSize)
	copy(mergedHandlers, group.Handlers)
	copy(mergedHandlers[len(group.Handlers):], handlers)
	return mergedHandlers
}

Abort还有如下几个变形,方便开发人员使用。

(1)AbortWithStatus,返回的时候设置HTTP响应码;

(2)AbortWithStatusJSON,返回的时候可以设置HTTP相应码、JSON数据;

(3)AbortWithError,返回的时候可以设置HTTP响应码和异常信息。

// AbortWithStatus calls `Abort()` and writes 
// the headers with the specified status code.
// For example, a failed attempt to authenticate a request 
// could use: context.AbortWithStatus(401).
func (c *Context) AbortWithStatus(code int) {
	c.Status(code)
	c.Writer.WriteHeaderNow()
	c.Abort()
}

// AbortWithStatusJSON calls `Abort()` and then `JSON` internally.
// This method stops the chain, writes the status code 
// and return a JSON body.
// It also sets the Content-Type as "application/json".
func (c *Context) AbortWithStatusJSON(code int, jsonObj interface{}) {
	c.Abort()
	c.JSON(code, jsonObj)
}

// AbortWithError calls `AbortWithStatus()` and `Error()` internally.
// This method stops the chain, writes the status code 
// and pushes the specified error to `c.Errors`.
// See Context.Error() for more details.
func (c *Context) AbortWithError(code int, err error) *Error {
	c.AbortWithStatus(code)
	return c.Error(err)
}

4.5 c.Get方法和c.Set方法

c.Set()和c.Get()这两个方法多用于在多个函数之间通过c传递数据的。比如我们可以在认证中间件中获取当前请求的相关信息(userID等)通过c.Set()存入c;然后在后续处理业务逻辑的函数中通过c.Get()来获取当前请求的用户 这个过程。

// Set is used to store a new key/value pair exclusively for this context.
// It also lazy initializes  c.Keys if it was not used previously.
func (c *Context) Set(key string, value interface{}) {
	c.mu.Lock()
	if c.Keys == nil {
		c.Keys = make(map[string]interface{})
	}

	c.Keys[key] = value
	c.mu.Unlock()
}

// Get returns the value for the given key, ie: (value, true).
// If the value does not exists it returns (nil, false)
func (c *Context) Get(key string) (value interface{}, exists bool) {
	c.mu.RLock()
	value, exists = c.Keys[key]
	c.mu.RUnlock()
	return
}

5 总结

整个中间件的处理流程设计的还是比较的精巧的,后面将分析一些常用的中间件的具体实现。