1package ext
2
3import (
4 "github.com/fasthttp/router"
5 "github.com/valyala/fasthttp"
6)
7
8type (
9 Router struct {
10 middlewares []Middleware
11 fastRouter *router.Router
12 }
13 Middleware func(next fasthttp.RequestHandler) fasthttp.RequestHandler
14 ErrorRequestHandler func(ctx *fasthttp.RequestCtx) error
15)
16
17func NewRouter(nestedRouter *router.Router) *Router {
18 return &Router{
19 fastRouter: nestedRouter,
20 }
21}
22
23func (self *Router) AddMiddleware(middleware Middleware) {
24 self.middlewares = append(self.middlewares, middleware)
25}
26
27func wrapError(next ErrorRequestHandler) fasthttp.RequestHandler {
28 return func(ctx *fasthttp.RequestCtx) {
29 if err := next(ctx); err != nil {
30 ctx.Response.SetStatusCode(500)
31 InternalServerError(ctx, err)
32 }
33 }
34}
35
36func (self *Router) run(next ErrorRequestHandler) fasthttp.RequestHandler {
37 return func(ctx *fasthttp.RequestCtx) {
38 req := wrapError(next)
39 for _, r := range self.middlewares {
40 req = r(req)
41 }
42 req(ctx)
43 }
44}
45
46func (self *Router) GET(path string, handler ErrorRequestHandler) {
47 self.fastRouter.GET(path, self.run(handler))
48}
49func (self *Router) POST(path string, handler ErrorRequestHandler) {
50 self.fastRouter.POST(path, self.run(handler))
51}