1
0
Fork 0
httpr/httpr.go

191 lines
4.6 KiB
Go

package httpr
import (
"context"
"errors"
"net/http"
"strings"
)
// path is a route or request path split into segments. Element 0 is always
// "/" and stands for the root; the remaining elements are the slash-separated
// segments that follow it.
type path []string

// newPath splits raw on "/" and ensures that it starts with a slash (/).
//
// A single trailing slash is ignored, so "/a/b/" and "/a/b" both yield
// {"/", "a", "b"}, and "/" yields {"/"}. An error is returned when raw does not
// begin with "/", including when raw is empty.
func newPath(raw string) (path, error) {
	// Checking the prefix up front also rejects "", which would otherwise
	// split into a single empty segment and be mistaken for the root.
	if !strings.HasPrefix(raw, "/") {
		return nil, errors.New("path should start with a slash (/) symbol")
	}
	parts := strings.Split(strings.TrimSuffix(raw, "/"), "/")
	parts[0] = "/"
	return parts, nil
}
// Params maps path parameter names to the values captured from a request
// path. A route segment written as :key (or *key for a catch-all) is stored
// under the name "key".
type Params map[string]string

// paramsKey is an unexported type, so no other package can construct a
// context key that collides with ParamsKey.
type paramsKey struct{}

// ParamsKey is the key under which Params are stored in a request's Context.
var ParamsKey = paramsKey{}
// node is one segment of a route tree. endpoint holds the segment text: a
// literal such as "users", a named parameter such as ":id", a catch-all such
// as "*filepath", or "/" for the root. handler is the function registered for
// the route that ends at this node, or nil if no route ends here.
type node struct {
	endpoint string
	children []*node
	handler  http.HandlerFunc
}

// get looks up the handler for p, where n is the node that matches p[idx].
//
// It returns the handler together with every parameter captured along the
// matched route, or (nil, nil) if no registered route matches. Literal
// children are preferred over a parameter/catch-all child, but when a literal
// branch does not lead to a handler the parameter/catch-all child is tried as
// well. p is not modified.
func (n *node) get(p path, idx int) (http.HandlerFunc, Params) {
	if n.endpoint[0] == '*' {
		// A catch-all consumes every remaining segment, slashes included.
		if n.handler == nil {
			return nil, nil
		}
		return n.handler, Params{n.endpoint[1:]: strings.Join(p[idx:], "/")}
	}

	var (
		handler http.HandlerFunc
		params  Params
	)
	if len(p) == idx+1 {
		handler, params = n.handler, Params{}
	} else {
		handler, params = n.getChild(p, idx+1)
	}
	if handler == nil {
		return nil, nil
	}

	if n.endpoint[0] == ':' {
		// Parameters are recorded on the way back up, so deeper ones are
		// already present. Keeping them when a name repeats makes the deepest
		// value win, as a left-to-right scan of the path would.
		name := n.endpoint[1:]
		if _, seen := params[name]; !seen {
			params[name] = p[idx]
		}
	}
	return handler, params
}

// getChild matches p[idx] against n's children: literal children first, then
// the parameter or catch-all child (add allows at most one per node) as the
// fallback.
func (n *node) getChild(p path, idx int) (http.HandlerFunc, Params) {
	var dynamic *node
	for _, child := range n.children {
		if c := child.endpoint[0]; c == ':' || c == '*' {
			dynamic = child
			continue
		}
		if child.endpoint == p[idx] {
			if h, params := child.get(p, idx); h != nil {
				return h, params
			}
		}
	}
	if dynamic != nil {
		return dynamic.get(p, idx)
	}
	return nil, nil
}

// add registers handler for the route p below n, where n corresponds to
// p[idx], creating any missing intermediate nodes.
//
// It returns an error if some segment of p is a parameter or catch-all and
// the node it would hang from already has a parameter or catch-all child with
// a different name. Errors from deeper levels are propagated, and a subtree
// that fails to register is not attached.
func (n *node) add(p path, idx int, handler http.HandlerFunc) error {
	// Last segment of the route: this node receives the handler.
	if len(p) == idx+1 {
		n.endpoint = p[idx]
		n.handler = handler
		return nil
	}

	next := p[idx+1]
	// NOTE(review): next[0] panics on an empty segment (a pattern containing
	// "//"); callers currently rely on that to reject such patterns.
	if next[0] == '*' || next[0] == ':' {
		for _, child := range n.children {
			if (child.endpoint[0] == '*' || child.endpoint[0] == ':') && next != child.endpoint {
				return errors.New("there is already a catch-all or regular param in there! You cannot add a second one")
			}
		}
	}

	// Reuse an existing child for this segment if there is one.
	for _, child := range n.children {
		if child.endpoint == next {
			return child.add(p, idx+1, handler)
		}
	}

	child := &node{endpoint: next}
	if err := child.add(p, idx+1, handler); err != nil {
		return err
	}
	n.children = append(n.children, child)
	return nil
}
// Router dispatches requests to handlers registered per HTTP method and path
// pattern. Create one with New: route registration writes into the tree map,
// which is nil in the zero value.
type Router struct {
	// tree holds one route tree per HTTP method.
	tree map[string]*node

	// NotFoundHandler, if non-nil, is called when the request's method has
	// routes but none of them matches the path. When nil, a plain
	// "Not Found" 404 response is written instead.
	NotFoundHandler http.HandlerFunc
}

// New returns a Router with no routes registered.
func New() *Router {
	rr := new(Router)
	rr.tree = map[string]*node{}
	return rr
}
// ServeHTTP dispatches r to the handler registered for its method and path.
//
// If no route exists for r.Method at all, it replies 405 Method Not Allowed.
// If the method has routes but none matches the path, NotFoundHandler is
// called, or a plain 404 Not Found is written when it is nil. The parameters
// of a matched route are stored in the request context under ParamsKey.
func (rr *Router) ServeHTTP(w http.ResponseWriter, r *http.Request) {
	tree, ok := rr.tree[r.Method]
	if !ok {
		http.Error(w, http.StatusText(http.StatusMethodNotAllowed), http.StatusMethodNotAllowed)
		return
	}
	// The path is client-controlled (for example the "*" target of an
	// OPTIONS request), so a path without a leading slash is answered as
	// not found instead of panicking.
	if !strings.HasPrefix(r.URL.Path, "/") {
		rr.notFound(w, r)
		return
	}
	p, _ := newPath(r.URL.Path) // cannot fail: the leading "/" was checked above
	handler, params := tree.get(p, 0)
	if handler == nil {
		rr.notFound(w, r)
		return
	}
	if params != nil {
		r = r.WithContext(context.WithValue(r.Context(), ParamsKey, params))
	}
	handler(w, r)
}

// notFound replies using NotFoundHandler when set, or a plain 404 otherwise.
func (rr *Router) notFound(w http.ResponseWriter, r *http.Request) {
	if rr.NotFoundHandler != nil {
		rr.NotFoundHandler(w, r)
		return
	}
	http.Error(w, "Not Found", http.StatusNotFound)
}
// Handler registers handler for pattern under the given HTTP method.
//
// pattern must start with "/". A segment of the form :name captures one path
// segment, and *name captures the rest of the path; a pattern may contain at
// most one "*". Handler panics on an invalid pattern and when pattern
// conflicts with an already registered route (for example, two differently
// named parameters at the same position), rather than silently dropping it.
func (rr *Router) Handler(method, pattern string, handler http.HandlerFunc) {
	// HasPrefix (rather than pattern[0]) also covers the empty pattern.
	if !strings.HasPrefix(pattern, "/") {
		panic("first element of path should be a slash (/) symbol")
	}
	if strings.Count(pattern, "*") > 1 {
		panic("there can be only one wildcard (*) symbol in path")
	}
	root := rr.tree[method]
	if root == nil {
		root = &node{endpoint: "/"}
		rr.tree[method] = root
	}
	p, _ := newPath(pattern) // cannot fail: the leading "/" was checked above
	if err := root.add(p, 0, handler); err != nil {
		panic("httpr: cannot register " + method + " " + pattern + ": " + err.Error())
	}
}
// ServeStatic serves files from root for GET requests matching pattern.
//
// pattern should end with /*filepath: the captured filepath parameter becomes
// the URL path handed to http.FileServer.
func (rr *Router) ServeStatic(pattern string, root http.FileSystem) {
	fileServer := http.FileServer(root)
	rr.Handler(http.MethodGet, pattern, func(w http.ResponseWriter, r *http.Request) {
		// r.URL is a pointer that may be shared with the caller's request
		// (Request.WithContext makes only a shallow copy), so rewrite the path
		// on a copy instead of mutating it in place.
		u := *r.URL
		u.Path = Param(r, "filepath")
		u.RawPath = ""
		r2 := new(http.Request)
		*r2 = *r
		r2.URL = &u
		fileServer.ServeHTTP(w, r2)
	})
}
// Param returns the value of the path parameter key (declared in a route as
// :key, e.g. `/a/b/:key/d`, or as *key for a catch-all) for r. It returns ""
// when r has no such parameter, including when r's context carries no Params
// at all (e.g. the request was not dispatched by a Router).
func Param(r *http.Request, key string) string {
	// Comma-ok assertion: a missing context value yields a nil Params, and
	// indexing a nil map returns "" instead of panicking.
	params, _ := r.Context().Value(ParamsKey).(Params)
	return params[key]
}