// Package httpr implements a small HTTP request router that matches request
// paths segment by segment against a tree of registered patterns, with
// support for named (":name") and catch-all ("*name") path parameters.
package httpr

import (
	"context"
	"errors"
	"net/http"
	"strings"
)

// path is a URL path split on "/". Element 0 is always "/", which addresses
// the root node of a method's routing tree.
type path []string

// newPath validates a route pattern and splits it into segments.
//
// The pattern must be non-empty, start with "/", contain at most one "*"
// and no empty segments (as in "/a//b"). A single trailing slash is ignored.
func newPath(pattern string) (path, error) {
	if pattern == "" {
		return nil, errors.New("empty path is not allowed")
	}
	if pattern[0] != '/' {
		return nil, errors.New("path should start with a slash symbol \"/\"")
	}
	if strings.Count(pattern, "*") > 1 {
		return nil, errors.New("path can have only one catch-all parameter \"*\"")
	}
	pattern = strings.TrimSuffix(pattern, "/")
	parts := strings.Split(pattern, "/")
	parts[0] = "/"
	// node.add and node.get inspect the first byte of every segment, so an
	// empty segment would make them panic; reject it at registration time.
	for _, part := range parts[1:] {
		if part == "" {
			return nil, errors.New("path should not contain empty segments")
		}
	}
	return parts, nil
}

// newServePath splits a request path into segments for lookup.
//
// It is more lenient than newPath: every run of slashes collapses into one
// and a single trailing slash is ignored, so no segment after the first is
// ever empty. An empty path or one not starting with "/" is rejected.
func newServePath(urlPath string) (path, error) {
	if urlPath == "" || urlPath[0] != '/' {
		return nil, errors.New("path should start with a slash symbol \"/\"")
	}
	// One ReplaceAll pass turns "///" into "//", so repeat until no doubled
	// slash remains.
	for strings.Contains(urlPath, "//") {
		urlPath = strings.ReplaceAll(urlPath, "//", "/")
	}
	urlPath = strings.TrimSuffix(urlPath, "/")
	parts := strings.Split(urlPath, "/")
	parts[0] = "/"
	return parts, nil
}

// Params holds the path parameters of a matched route, keyed by the name
// that follows ":" (or "*" for a catch-all) in the route pattern.
type Params map[string]string

type paramsKey struct{}

// ParamsKey is the context key under which ServeHTTP stores the Params of a
// matched route in the request's Context.
var ParamsKey paramsKey = paramsKey{}

// node is one segment of a method's routing tree. endpoint holds the segment
// text: a literal, ":name" for a named parameter, "*name" for a catch-all,
// or "/" for the root. handler is nil for nodes that only lead to deeper
// routes.
type node struct {
	endpoint string
	children []*node
	handler  http.HandlerFunc
}

// get walks the tree rooted at n along the request segments p (p[0]
// addresses n itself) and returns the handler of the node reached together
// with every parameter collected on the way.
//
// At each level a child whose endpoint equals the next segment wins;
// otherwise a ":name" or "*name" child is used. A catch-all node consumes
// all remaining segments, joined with "/". It returns nil, nil when the walk
// dead-ends; when a node is reached, Params is non-nil but the handler may
// still be nil if no route was registered exactly there.
func (n *node) get(p path) (http.HandlerFunc, Params) {
	params := make(Params)
	cur := n
	for i := range p {
		switch cur.endpoint[0] {
		case '*':
			params[cur.endpoint[1:]] = strings.Join(p[i:], "/")
			return cur.handler, params
		case ':':
			// Store the raw segment: values may legitimately contain ':'.
			params[cur.endpoint[1:]] = p[i]
		}
		if i == len(p)-1 {
			return cur.handler, params
		}
		next := cur.matchChild(p[i+1])
		if next == nil {
			return nil, nil
		}
		cur = next
	}
	return nil, nil
}

// matchChild picks the child of n for a request segment: a child with an
// identical endpoint is returned immediately; otherwise the last ":name" or
// "*name" child seen is returned, or nil if there is none.
func (n *node) matchChild(segment string) *node {
	var wildcard *node
	for _, c := range n.children {
		if c.endpoint == segment {
			return c
		}
		if c.endpoint[0] == ':' || c.endpoint[0] == '*' {
			wildcard = c
		}
	}
	return wildcard
}

// add registers handler at the node addressed by the pattern segments p
// (p[0] addresses n itself), creating intermediate nodes as needed.
//
// It fails if a handler already exists for the same pattern, or if a
// parameter at some position is named differently from one registered
// earlier at that position.
func (n *node) add(p path, handler http.HandlerFunc) error {
	cur := n
	for i := range p {
		if i == len(p)-1 {
			if cur.handler != nil {
				return errors.New("attempt to redefine a handler for already existing path")
			}
			cur.endpoint = p[i]
			cur.handler = handler
			return nil
		}
		next, err := cur.insertChild(p[i+1])
		if err != nil {
			return err
		}
		cur = next
	}
	return nil
}

// insertChild returns the child of n for a pattern segment, appending a new
// one if none matches. A ":" or "*" segment reuses an existing child of the
// same kind. Differing names are an error: get only ever reaches that first
// child, so a parameter registered under a second name would never be set.
func (n *node) insertChild(segment string) (*node, error) {
	kind := segment[0]
	for _, c := range n.children {
		if (kind == ':' || kind == '*') && kind == c.endpoint[0] {
			if segment != c.endpoint {
				return nil, errors.New("param names " + segment + " and " + c.endpoint + " are differ")
			}
			return c, nil
		}
		if c.endpoint == segment {
			return c, nil
		}
	}
	c := &node{endpoint: segment}
	n.children = append(n.children, c)
	return c, nil
}

// Router dispatches requests to handlers registered per HTTP method and path
// pattern. If NotFoundHandler is non-nil it handles requests whose path
// matches no route; otherwise a plain 404 "Not Found" is written. New
// returns a ready Router; the zero value works too, because Handler
// allocates the method table on first use.
type Router struct {
	tree            map[string]*node
	NotFoundHandler http.HandlerFunc
}

// New returns an empty Router.
func New() *Router {
	return &Router{tree: make(map[string]*node)}
}

// ServeHTTP implements http.Handler.
//
// It replies 501 when no route is registered for the request method and 406
// when the path does not start with "/". Otherwise it calls the matching
// handler, with the route's Params stored in the request context under
// ParamsKey, or falls back to NotFoundHandler / a plain 404.
func (rr *Router) ServeHTTP(w http.ResponseWriter, r *http.Request) {
	tree, ok := rr.tree[r.Method]
	if !ok {
		http.Error(w, http.StatusText(http.StatusNotImplemented), http.StatusNotImplemented)
		return
	}
	p, err := newServePath(r.URL.Path)
	if err != nil {
		http.Error(w, err.Error(), http.StatusNotAcceptable)
		return
	}
	handler, params := tree.get(p)
	if handler == nil {
		if rr.NotFoundHandler != nil {
			rr.NotFoundHandler(w, r)
			return
		}
		http.Error(w, "Not Found", http.StatusNotFound)
		return
	}
	if params != nil {
		r = r.WithContext(context.WithValue(r.Context(), ParamsKey, params))
	}
	handler(w, r)
}

// Handler registers handler for the given HTTP method and pattern.
//
// Pattern must start with a slash (/). Segments of the form ":name" match
// any single segment and "*name" matches the rest of the path; the values
// are available through Param. It returns an error for an invalid pattern,
// a nil handler, an already registered route, or a parameter whose name
// conflicts with one registered earlier at the same position.
func (rr *Router) Handler(method, pattern string, handler http.HandlerFunc) error {
	p, err := newPath(pattern)
	if err != nil {
		return err
	}
	if handler == nil {
		return errors.New("handler must not be nil")
	}
	if rr.tree == nil {
		rr.tree = make(map[string]*node)
	}
	root, ok := rr.tree[method]
	if !ok {
		root = &node{endpoint: "/"}
		rr.tree[method] = root
	}
	return root.add(p, handler)
}

// ServeStatic serves files from root for GET requests matching pattern.
//
// Pattern should end with /*filepath to work: the part of the request path
// captured by that catch-all is the path looked up in root.
func (rr *Router) ServeStatic(pattern string, root http.FileSystem) error {
	fileServer := http.FileServer(root)
	return rr.Handler(http.MethodGet, pattern, func(w http.ResponseWriter, r *http.Request) {
		// Handlers must not modify the request they receive, and the request
		// built by WithContext shares its *url.URL with the caller's, so
		// rewrite the path on a copy of both.
		r2 := new(http.Request)
		*r2 = *r
		u := *r.URL
		u.Path = Param(r, "filepath")
		r2.URL = &u
		fileServer.ServeHTTP(w, r2)
	})
}

// subPath prepends a base path to every pattern registered through it.
//
// It is not a sub-router: it just passes the combined pattern down to the
// parent Router.
type subPath struct { r *Router base string } // Sub creates a group of handlers with the same base path. // // How to use: // // r := httpr.New() // ... // s := r.Sub("/api/v1") // s.Handler(http.MethodGet, "/", func(w, r) {...}) // s.Handler(http.MethodGet, "/section", func(w, r) {...}) func (rr *Router) Sub(base string) *subPath { if base[len(base)-1] == '/' { base = base[:len(base)-1] } return &subPath{ r: rr, base: base, } } // Handler registers a handler for a sub-path. func (sp *subPath) Handler(method, pattern string, handler http.HandlerFunc) error { return sp.r.Handler(method, sp.base+pattern, handler) } // Param returns a URL parameter set with :key, or an empty string if not found. func Param(r *http.Request, key string) string { if params := r.Context().Value(ParamsKey).(Params); params != nil { return params[key] } return "" }