package httpr import ( "context" "errors" "net/http" "strings" ) type path []string // newPath splits a path and ensures that it starts with a slash (/) and doesn't // have more than 1 catch-all parameter. func newPath(path string) (path, error) { if path[0] != '/' { return nil, errors.New("path should start with a slash (/) symbol") } if strings.Count(path, "*") > 1 { return nil, errors.New("there can be only one catch-all (*) parameter in path") } parts := strings.Split(strings.TrimSuffix(path, "/"), "/") parts[0] = "/" return parts, nil } // Params holds path parameters that are set as :key. type Params map[string]string type paramsKey struct{} // ParamsKey is used as a key for Params in a request's Context. var ParamsKey paramsKey = paramsKey{} type node struct { endpoint string children []*node handler http.HandlerFunc } func (n *node) get(path path, idx int) (http.HandlerFunc, Params) { // Check if this node is a catch-all endpoint. if n.endpoint[0] == '*' { var p Params = Params{} p[n.endpoint[1:]] = strings.Join(path[idx:], "/") return n.handler, p } // If this endpoint is a parameter, then add its name to a path's part. // This will be used further to fill Params. if n.endpoint[0] == ':' { path[idx] = n.endpoint + ":" + path[idx] } if len(path) == idx+1 { var params Params = make(Params) for _, part := range path { if part[0] == ':' { param := strings.Split(part[1:], ":") params[param[0]] = param[1] } } return n.handler, params } if len(path) > idx+1 { var wildcardOrParam *node for _, next := range n.children { if next.endpoint == path[idx+1] { return next.get(path, idx+1) } if next.endpoint[0] == ':' || next.endpoint[0] == '*' { wildcardOrParam = next } } if wildcardOrParam != nil { return wildcardOrParam.get(path, idx+1) } } return nil, nil } func (n *node) add(path path, idx int, handler http.HandlerFunc) error { // If it is a last part of path, then set a handler to this node. if len(path) == idx+1 { n.endpoint = path[idx] n.handler = handler return nil } // Check if next part is a parameter and if it is, then look for // an already existing endpoint with a different key. if path[idx+1][0] == '*' || path[idx+1][0] == ':' { for _, child := range n.children { if (child.endpoint[0] == '*' || child.endpoint[0] == ':') && path[idx+1] != child.endpoint { return errors.New("there is already a catch-all or regular param in there! You cannot add a second one") } } } // Check for an already existing endpoint. for _, child := range n.children { if child.endpoint == path[idx+1] { child.add(path, idx+1, handler) return nil } } // No endpoint was found. new_child := &node{endpoint: path[idx+1]} new_child.add(path, idx+1, handler) n.children = append(n.children, new_child) return nil } type Router struct { tree map[string]*node NotFoundHandler http.HandlerFunc } func New() *Router { return &Router{tree: make(map[string]*node)} } func (rr *Router) ServeHTTP(w http.ResponseWriter, r *http.Request) { if tree, ok := rr.tree[r.Method]; ok { if r.URL.Path[0] != '/' { panic("first element of path should be a slash (/) symbol") } path, _ := newPath(r.URL.Path) if handler, params := tree.get(path, 0); handler != nil { if params != nil { r = r.WithContext(context.WithValue(r.Context(), ParamsKey, params)) } handler(w, r) } else { if rr.NotFoundHandler != nil { rr.NotFoundHandler(w, r) } else { http.Error(w, "Not Found", http.StatusNotFound) } } } else { http.Error(w, http.StatusText(http.StatusMethodNotAllowed), http.StatusMethodNotAllowed) } } // Handler registers a handler for provided pattern for a given HTTP method. func (rr *Router) Handler(method, pattern string, handler http.HandlerFunc) error { path, err := newPath(pattern) if err != nil { return err } if rr.tree[method] == nil { rr.tree[method] = &node{endpoint: "/"} } if err := rr.tree[method].add(path, handler); err != nil { return err } return nil } // ServeStatic serves a given file system. // // Path should end with /*filepath to work. func (rr *Router) ServeStatic(path string, root http.FileSystem) { fileServer := http.FileServer(root) rr.Handler(http.MethodGet, path, func(w http.ResponseWriter, r *http.Request) { r.URL.Path = Param(r, "filepath") fileServer.ServeHTTP(w, r) }) } // Param returns a parameter (that is set like `/a/b/:key/d`) inside a path // with a key or empty string if no such parameter found. func Param(r *http.Request, key string) string { if params := r.Context().Value(ParamsKey).(Params); params != nil { return params[key] } return "" }