package httpr import ( "context" "errors" "net/http" "strings" ) type path []string // newPath splits a path and ensures that it starts with a slash (/) and doesn't // have more than 1 catch-all parameter. func newPath(path string) (path, error) { if path[0] != '/' { return nil, errors.New("path should start with a slash (/) symbol") } if strings.Count(path, "*") > 1 { return nil, errors.New("there can be only one catch-all (*) parameter in path") } parts := strings.Split(strings.TrimSuffix(path, "/"), "/") parts[0] = "/" return parts, nil } // Params holds path parameters that are set as :key. type Params map[string]string type paramsKey struct{} // ParamsKey is used as a key for Params in a request's Context. var ParamsKey paramsKey = paramsKey{} type node struct { endpoint string children []*node handler http.HandlerFunc } func (n *node) get(path path) (http.HandlerFunc, Params) { pathLen := len(path) curNode := n outer: for i := range path { // Check if this node is a catch-all endpoint. if curNode.endpoint[0] == '*' { var p Params = Params{} p[curNode.endpoint[1:]] = strings.Join(path[i:], "/") return curNode.handler, p } // If this is a parametrised endpoint, then add its name to // a path's part. It will be used further to parse parameters. if curNode.endpoint[0] == ':' { path[i] = curNode.endpoint + ":" + path[i] } if pathLen == i+1 { var params Params = make(Params) for _, part := range path { if part[0] == ':' { param := strings.Split(part[1:], ":") params[param[0]] = param[1] } } return n.handler, params } if pathLen > i+1 { var paramNode *node for _, next := range n.children { if next.endpoint == path[i+1] { curNode = next continue outer } if next.endpoint[0] == ':' || next.endpoint[0] == '*' { paramNode = next } } if paramNode != nil { curNode = paramNode continue } } } return nil, nil } func (n *node) add(path path, handler http.HandlerFunc) error { pathLen := len(path) curNode := n outer: for i := range path { if pathLen == i+1 { if curNode.handler != nil { return errors.New("attempt to redefine a handler for already existing path") } curNode.endpoint = path[i] curNode.handler = handler return nil } for _, child := range curNode.children { firstChar := path[i+1][0] if (firstChar == ':' || firstChar == '*') && firstChar == child.endpoint[0] { curNode = child continue outer } if child.endpoint == path[i+1] { curNode = child continue outer } } newChild := &node{endpoint: path[i+1]} curNode.children = append(curNode.children, newChild) curNode = newChild } return nil } type Router struct { tree map[string]*node NotFoundHandler http.HandlerFunc } func New() *Router { return &Router{tree: make(map[string]*node)} } func (rr *Router) ServeHTTP(w http.ResponseWriter, r *http.Request) { if tree, ok := rr.tree[r.Method]; ok { path, err := newPath(r.URL.Path) if err != nil { http.Error(w, err.Error(), http.StatusNotAcceptable) return } if handler, params := tree.get(path); handler != nil { if params != nil { r = r.WithContext(context.WithValue(r.Context(), ParamsKey, params)) } handler(w, r) } else { if rr.NotFoundHandler != nil { rr.NotFoundHandler(w, r) } else { http.Error(w, "Not Found", http.StatusNotFound) } } } else { http.Error(w, http.StatusText(http.StatusMethodNotAllowed), http.StatusMethodNotAllowed) } } // Handler registers a handler for provided pattern for a given HTTP method. func (rr *Router) Handler(method, pattern string, handler http.HandlerFunc) error { path, err := newPath(pattern) if err != nil { return err } if rr.tree[method] == nil { rr.tree[method] = &node{endpoint: "/"} } if err := rr.tree[method].add(path, handler); err != nil { return err } return nil } // ServeStatic serves a given file system. // // Path should end with /*filepath to work. func (rr *Router) ServeStatic(path string, root http.FileSystem) error { fileServer := http.FileServer(root) return rr.Handler(http.MethodGet, path, func(w http.ResponseWriter, r *http.Request) { r.URL.Path = Param(r, "filepath") fileServer.ServeHTTP(w, r) }) } // Param returns a parameter (that is set like `/a/b/:key/d`) inside a path // with a key or empty string if no such parameter found. func Param(r *http.Request, key string) string { if params := r.Context().Value(ParamsKey).(Params); params != nil { return params[key] } return "" }