Compare commits
32 Commits
Author | SHA1 | Date | |
---|---|---|---|
32ae3a3d0d | |||
32e7468eef | |||
3e6d03db1a | |||
c237f8c566 | |||
a23264b00f | |||
b4163d2162 | |||
5d613b34ee | |||
cc2cd72df8 | |||
92692454da | |||
a2cb6182e8 | |||
aba211f3ec | |||
d53622908b | |||
7c8baeecf5 | |||
468606e4fd | |||
d9e5024d4d | |||
c68d7b324a | |||
3cb32c5ec9 | |||
0717a2e3d3 | |||
e25a8a42c3 | |||
bc11a46806 | |||
33de30fe23 | |||
5d6a3630c6 | |||
fcbd09506a | |||
2bfaae11f2 | |||
257bd7ea76 | |||
653bae85f5 | |||
4f54ab4156 | |||
9cda541108 | |||
89c2333a4f | |||
538f1bd676 | |||
a0b80ced85 | |||
99a7cebd0a |
2
LICENSE
2
LICENSE
@ -1,6 +1,6 @@
|
||||
MIT License
|
||||
|
||||
Copyright (c) <year> <copyright holders>
|
||||
Copyright (c) 2023 Alexander "Arav" Andreev <me@arav.su>
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
|
||||
|
||||
|
40
README.md
40
README.md
@ -3,19 +3,33 @@
|
||||
It is an implementation of yet another HTTP router.
|
||||
|
||||
The reason why this router was made is to be able to have pretty paths with
|
||||
parameters and standard endpoints at the same level.
|
||||
parameters and regular endpoints at the same level. Like this:
|
||||
|
||||
As an example here is a structure used in my another project
|
||||
(dwelling-upload):
|
||||
|
||||
GET /
|
||||
POST /
|
||||
GET /:hash/:name
|
||||
POST /delete
|
||||
DELETE /:hash
|
||||
GET /:a/:b
|
||||
GET /assets/*filepath
|
||||
GET /robots.txt
|
||||
GET /favicon.svg
|
||||
|
||||
Previously I used httprouter and I had to have `/f/:hash/:name` route
|
||||
instead of just `/:hash/:name` because of collisions.
|
||||
In routers like httprouter this is not allowed.
|
||||
|
||||
This router is used like many others., example:
|
||||
|
||||
r := httpr.New()
|
||||
|
||||
r.Handler(http.MethodGet, "/", func(w http.ResponseWriter, r *http.Request) {
|
||||
...
|
||||
})
|
||||
|
||||
r.ServeStatic("/assets/*filepath", http.FS(os.Dir(".")))
|
||||
|
||||
r.NotFoundHandler = func(w http.ResponseWriter, r *http.Request) {
|
||||
...
|
||||
}
|
||||
|
||||
s := r.Sub("/api/v1")
|
||||
|
||||
s.Handler(http.MethodGet, "/", func(w http.ResponseWriter, r *http.Request) {
|
||||
...
|
||||
})
|
||||
|
||||
if err := http.ListenAndServe(":8000", r); err != nil {
|
||||
...
|
||||
}
|
||||
|
288
httpr.go
288
httpr.go
@ -9,13 +9,50 @@ import (
|
||||
|
||||
type path []string
|
||||
|
||||
// newPath splits a path and ensures that it starts with a slash (/).
|
||||
// newPath ensures that a path provided is correct and splits it.
|
||||
func newPath(path string) (path, error) {
|
||||
parts := strings.Split(strings.TrimSuffix(path, "/"), "/")
|
||||
if parts[0] != "" {
|
||||
return nil, errors.New("path should start with a slash (/) symbol")
|
||||
pathLen := len(path)
|
||||
|
||||
if pathLen == 0 {
|
||||
return nil, errors.New("empty path is not allowed")
|
||||
}
|
||||
|
||||
if path[0] != '/' {
|
||||
return nil, errors.New("path should start with a slash symbol \"/\"")
|
||||
}
|
||||
|
||||
if strings.Count(path, "*") > 1 {
|
||||
return nil, errors.New("path can have only one catch-all parameter \"*\"")
|
||||
}
|
||||
|
||||
if path[pathLen-1] == '/' {
|
||||
path = path[:pathLen-1]
|
||||
}
|
||||
|
||||
parts := strings.Split(path, "/")
|
||||
|
||||
parts[0] = "/"
|
||||
|
||||
return parts, nil
|
||||
}
|
||||
|
||||
// newServePath is a reduced version of newPath for ServeHTTP.
|
||||
func newServePath(path string) (path, error) {
|
||||
if path[0] != '/' {
|
||||
return nil, errors.New("path should start with a slash symbol \"/\"")
|
||||
}
|
||||
|
||||
path = strings.ReplaceAll(path, "//", "/")
|
||||
|
||||
pathLen := len(path)
|
||||
if path[pathLen-1] == '/' {
|
||||
path = path[:pathLen-1]
|
||||
}
|
||||
|
||||
parts := strings.Split(path, "/")
|
||||
|
||||
parts[0] = "/"
|
||||
|
||||
return parts, nil
|
||||
}
|
||||
|
||||
@ -33,83 +70,111 @@ type node struct {
|
||||
handler http.HandlerFunc
|
||||
}
|
||||
|
||||
func (n *node) get(path path, idx int) (http.HandlerFunc, Params) {
|
||||
// Check if this node is a catch-all endpoint.
|
||||
if n.endpoint[0] == '*' {
|
||||
var p Params = Params{}
|
||||
p[n.endpoint[1:]] = strings.Join(path[idx:], "/")
|
||||
return n.handler, p
|
||||
}
|
||||
func (n *node) get(path path) (http.HandlerFunc, Params) {
|
||||
pathLen := len(path)
|
||||
curNode := n
|
||||
|
||||
// If this endpoint is a parameter, then add its name to a path's part.
|
||||
// This will be used further to fill Params.
|
||||
if n.endpoint[0] == ':' {
|
||||
path[idx] = n.endpoint + ":" + path[idx]
|
||||
}
|
||||
|
||||
if len(path) == idx+1 {
|
||||
var params Params = make(Params)
|
||||
|
||||
for _, part := range path {
|
||||
if part[0] == ':' {
|
||||
param := strings.Split(part[1:], ":")
|
||||
params[param[0]] = param[1]
|
||||
}
|
||||
outer:
|
||||
for i := range path {
|
||||
// Check if this node is a catch-all endpoint.
|
||||
if curNode.endpoint[0] == '*' {
|
||||
var p Params = Params{}
|
||||
p[curNode.endpoint[1:]] = strings.Join(path[i:], "/")
|
||||
return curNode.handler, p
|
||||
}
|
||||
|
||||
return n.handler, params
|
||||
}
|
||||
|
||||
if len(path) > idx+1 {
|
||||
var wildcardOrParam *node
|
||||
for _, next := range n.children {
|
||||
if next.endpoint == path[idx+1] {
|
||||
return next.get(path, idx+1)
|
||||
}
|
||||
|
||||
if next.endpoint[0] == ':' || next.endpoint[0] == '*' {
|
||||
wildcardOrParam = next
|
||||
}
|
||||
// If this is a parametrised endpoint, then add its name to
|
||||
// a path's part. It will be used further to parse parameters.
|
||||
if curNode.endpoint[0] == ':' {
|
||||
path[i] = curNode.endpoint + ":" + path[i]
|
||||
}
|
||||
|
||||
if wildcardOrParam != nil {
|
||||
return wildcardOrParam.get(path, idx+1)
|
||||
pathNextIdx := i + 1
|
||||
|
||||
if pathLen == pathNextIdx {
|
||||
var params Params = make(Params)
|
||||
|
||||
for _, part := range path {
|
||||
if part[0] == ':' {
|
||||
param := strings.Split(part[1:], ":")
|
||||
params[param[0]] = param[1]
|
||||
}
|
||||
}
|
||||
|
||||
return curNode.handler, params
|
||||
}
|
||||
|
||||
if pathLen > pathNextIdx {
|
||||
if len(curNode.children) == 0 {
|
||||
break outer
|
||||
}
|
||||
|
||||
var paramNode *node
|
||||
|
||||
for _, next := range curNode.children {
|
||||
if next.endpoint == path[pathNextIdx] {
|
||||
curNode = next
|
||||
continue outer
|
||||
}
|
||||
|
||||
if next.endpoint[0] == ':' || next.endpoint[0] == '*' {
|
||||
paramNode = next
|
||||
}
|
||||
}
|
||||
|
||||
if paramNode != nil {
|
||||
curNode = paramNode
|
||||
continue outer
|
||||
}
|
||||
|
||||
break outer
|
||||
}
|
||||
}
|
||||
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (n *node) add(path path, idx int, handler http.HandlerFunc) error {
|
||||
// If it is a last part of path, then set a handler to this node.
|
||||
if len(path) == idx+1 {
|
||||
n.endpoint = path[idx]
|
||||
n.handler = handler
|
||||
return nil
|
||||
}
|
||||
func (n *node) add(path path, handler http.HandlerFunc) error {
|
||||
pathLastIdx := len(path) - 1
|
||||
curNode := n
|
||||
|
||||
// Check if next part is a parameter and if it is, then look for
|
||||
// an already existing endpoint with a different key.
|
||||
if path[idx+1][0] == '*' || path[idx+1][0] == ':' {
|
||||
for _, child := range n.children {
|
||||
if (child.endpoint[0] == '*' || child.endpoint[0] == ':') && path[idx+1] != child.endpoint {
|
||||
return errors.New("there is already a catch-all or regular param in there! You cannot add a second one")
|
||||
outer:
|
||||
for i := range path {
|
||||
if pathLastIdx == i {
|
||||
if curNode.handler != nil {
|
||||
return errors.New("attempt to redefine a handler for already existing path")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Check for an already existing endpoint.
|
||||
for _, child := range n.children {
|
||||
if child.endpoint == path[idx+1] {
|
||||
child.add(path, idx+1, handler)
|
||||
curNode.endpoint = path[i]
|
||||
curNode.handler = handler
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// No endpoint was found.
|
||||
new_child := &node{endpoint: path[idx+1]}
|
||||
new_child.add(path, idx+1, handler)
|
||||
n.children = append(n.children, new_child)
|
||||
pathNextIdx := i + 1
|
||||
|
||||
for _, child := range curNode.children {
|
||||
firstChar := path[pathNextIdx][0]
|
||||
if (firstChar == ':' || firstChar == '*') && firstChar == child.endpoint[0] {
|
||||
// Do not allow different param names, because only the first one
|
||||
// is saved, so a param won't be available by a new name.
|
||||
// Therefore, it is good to return an error because in this case
|
||||
// you're doing something wrong.
|
||||
if path[pathNextIdx] != child.endpoint {
|
||||
return errors.New("param names " + path[pathNextIdx] + " and " + child.endpoint + " are differ")
|
||||
}
|
||||
curNode = child
|
||||
continue outer
|
||||
}
|
||||
|
||||
if child.endpoint == path[pathNextIdx] {
|
||||
curNode = child
|
||||
continue outer
|
||||
}
|
||||
}
|
||||
|
||||
newChild := &node{endpoint: path[pathNextIdx]}
|
||||
curNode.children = append(curNode.children, newChild)
|
||||
curNode = newChild
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@ -124,63 +189,94 @@ func New() *Router {
|
||||
}
|
||||
|
||||
func (rr *Router) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
if tree, ok := rr.tree[r.Method]; ok {
|
||||
if r.URL.Path[0] != '/' {
|
||||
panic("first element of path should be a slash (/) symbol")
|
||||
}
|
||||
tree, ok := rr.tree[r.Method]
|
||||
if !ok {
|
||||
http.Error(w, http.StatusText(http.StatusNotImplemented), http.StatusNotImplemented)
|
||||
return
|
||||
}
|
||||
|
||||
path, _ := newPath(r.URL.Path)
|
||||
path, err := newServePath(r.URL.Path)
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusNotAcceptable)
|
||||
return
|
||||
}
|
||||
|
||||
if handler, params := tree.get(path, 0); handler != nil {
|
||||
if params != nil {
|
||||
r = r.WithContext(context.WithValue(r.Context(), ParamsKey, params))
|
||||
}
|
||||
handler(w, r)
|
||||
} else {
|
||||
if rr.NotFoundHandler != nil {
|
||||
rr.NotFoundHandler(w, r)
|
||||
} else {
|
||||
http.Error(w, "Not Found", http.StatusNotFound)
|
||||
}
|
||||
if handler, params := tree.get(path); handler != nil {
|
||||
if params != nil {
|
||||
r = r.WithContext(context.WithValue(r.Context(), ParamsKey, params))
|
||||
}
|
||||
handler(w, r)
|
||||
} else {
|
||||
http.Error(w, http.StatusText(http.StatusMethodNotAllowed), http.StatusMethodNotAllowed)
|
||||
if rr.NotFoundHandler != nil {
|
||||
rr.NotFoundHandler(w, r)
|
||||
} else {
|
||||
http.Error(w, "Not Found", http.StatusNotFound)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Handler registers a handler for provided pattern for a given HTTP method.
|
||||
func (rr *Router) Handler(method, pattern string, handler http.HandlerFunc) {
|
||||
if pattern[0] != '/' {
|
||||
panic("first element of path should be a slash (/) symbol")
|
||||
}
|
||||
|
||||
if strings.Count(pattern, "*") > 1 {
|
||||
panic("there can be only one wildcard (*) symbol in path")
|
||||
// Pattern must start with a slash (/) symbol.
|
||||
func (rr *Router) Handler(method, pattern string, handler http.HandlerFunc) error {
|
||||
path, err := newPath(pattern)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if rr.tree[method] == nil {
|
||||
rr.tree[method] = &node{endpoint: "/"}
|
||||
}
|
||||
|
||||
path, _ := newPath(pattern)
|
||||
|
||||
rr.tree[method].add(path, 0, handler)
|
||||
return rr.tree[method].add(path, handler)
|
||||
}
|
||||
|
||||
// ServeStatic serves a given file system.
|
||||
//
|
||||
// Path should end with /*filepath to work.
|
||||
func (rr *Router) ServeStatic(path string, root http.FileSystem) {
|
||||
func (rr *Router) ServeStatic(path string, root http.FileSystem) error {
|
||||
fileServer := http.FileServer(root)
|
||||
|
||||
rr.Handler(http.MethodGet, path, func(w http.ResponseWriter, r *http.Request) {
|
||||
return rr.Handler(http.MethodGet, path, func(w http.ResponseWriter, r *http.Request) {
|
||||
r.URL.Path = Param(r, "filepath")
|
||||
fileServer.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
|
||||
// Param returns a parameter (that is set like `/a/b/:key/d`) inside a path
|
||||
// with a key or empty string if no such parameter found.
|
||||
// subPath attaches a base path in front of a pattern.
|
||||
//
|
||||
// It is not a sub-router, it just passes a resulted pattern down to
|
||||
// a router instance.
|
||||
type subPath struct {
|
||||
r *Router
|
||||
base string
|
||||
}
|
||||
|
||||
// Sub creates a group of handlers with the same base path.
|
||||
//
|
||||
// How to use:
|
||||
//
|
||||
// r := httpr.New()
|
||||
// ...
|
||||
// s := r.Sub("/api/v1")
|
||||
// s.Handler(http.MethodGet, "/", func(w, r) {...})
|
||||
// s.Handler(http.MethodGet, "/section", func(w, r) {...})
|
||||
func (rr *Router) Sub(base string) *subPath {
|
||||
if base[len(base)-1] == '/' {
|
||||
base = base[:len(base)-1]
|
||||
}
|
||||
|
||||
return &subPath{
|
||||
r: rr,
|
||||
base: base,
|
||||
}
|
||||
}
|
||||
|
||||
// Handler registers a handler for a sub-path.
|
||||
func (sp *subPath) Handler(method, pattern string, handler http.HandlerFunc) error {
|
||||
return sp.r.Handler(method, sp.base+pattern, handler)
|
||||
}
|
||||
|
||||
// Param returns a URL parameter set with :key, or an empty string if not found.
|
||||
func Param(r *http.Request, key string) string {
|
||||
if params := r.Context().Value(ParamsKey).(Params); params != nil {
|
||||
return params[key]
|
||||
|
156
httpr_test.go
Normal file
156
httpr_test.go
Normal file
@ -0,0 +1,156 @@
|
||||
package httpr
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func Test(t *testing.T) {
|
||||
r := New()
|
||||
|
||||
err := r.Handler(http.MethodGet, "/", func(w http.ResponseWriter, r *http.Request) {})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
err = r.Handler(http.MethodGet, "/", func(w http.ResponseWriter, r *http.Request) {})
|
||||
if err == nil {
|
||||
t.Fatal("path redefinition wasn't catched")
|
||||
}
|
||||
|
||||
err = r.Handler(http.MethodGet, "/a/b", func(w http.ResponseWriter, r *http.Request) {})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
err = r.Handler(http.MethodGet, "/:a/:b", func(w http.ResponseWriter, r *http.Request) {})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
err = r.Handler(http.MethodGet, "/:a/:lol", func(w http.ResponseWriter, r *http.Request) {})
|
||||
if err == nil {
|
||||
t.Fatal("here is a different last param name is supplied, should be catched")
|
||||
}
|
||||
|
||||
err = r.ServeStatic("/assets/*filepath", http.FS(os.DirFS(".")))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
err = r.ServeStatic("/assets/*filepath/*filepath", nil)
|
||||
if err == nil {
|
||||
t.Fatal("multiple catch-all params wasn't catched")
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func TestPaths(t *testing.T) {
|
||||
found := false
|
||||
|
||||
r := New()
|
||||
|
||||
err := r.Handler(http.MethodGet, "/:lel", func(w http.ResponseWriter, r *http.Request) { found = true })
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
r.NotFoundHandler = func(w http.ResponseWriter, r *http.Request) { found = false }
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
p := "/xmpp://me@arav.su"
|
||||
req := httptest.NewRequest(http.MethodGet, p, strings.NewReader(""))
|
||||
r.ServeHTTP(w, req)
|
||||
if found {
|
||||
t.Error("Path", p, "should return 404")
|
||||
}
|
||||
|
||||
p = "/lel"
|
||||
req = httptest.NewRequest(http.MethodGet, p, strings.NewReader(""))
|
||||
r.ServeHTTP(w, req)
|
||||
if !found {
|
||||
t.Error("Path", p, "should return 200")
|
||||
}
|
||||
|
||||
p = "/lel/lol"
|
||||
req = httptest.NewRequest(http.MethodGet, p, strings.NewReader(""))
|
||||
r.ServeHTTP(w, req)
|
||||
if found {
|
||||
t.Error("Path", p, "should return 404")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSubPaths(t *testing.T) {
|
||||
found := true
|
||||
|
||||
r := New()
|
||||
|
||||
s := r.Sub("/api/v1")
|
||||
|
||||
err := s.Handler(http.MethodGet, "/", func(w http.ResponseWriter, r *http.Request) { found = true })
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
err = s.Handler(http.MethodGet, "/test", func(w http.ResponseWriter, r *http.Request) { found = true })
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
r.NotFoundHandler = func(w http.ResponseWriter, r *http.Request) { found = false }
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
p := "/api/v1/"
|
||||
req := httptest.NewRequest(http.MethodGet, p, strings.NewReader(""))
|
||||
r.ServeHTTP(w, req)
|
||||
if !found {
|
||||
t.Error("Path", p, "should return 200")
|
||||
}
|
||||
|
||||
p = "/api/v1/test"
|
||||
req = httptest.NewRequest(http.MethodGet, p, strings.NewReader(""))
|
||||
r.ServeHTTP(w, req)
|
||||
if !found {
|
||||
t.Error("Path", p, "should return 200")
|
||||
}
|
||||
|
||||
p = "/api/v1/nonexistent"
|
||||
req = httptest.NewRequest(http.MethodGet, p, strings.NewReader(""))
|
||||
r.ServeHTTP(w, req)
|
||||
if found {
|
||||
t.Error(found, "Path", p, "should return 404")
|
||||
}
|
||||
}
|
||||
|
||||
func TestPathParsing(t *testing.T) {
|
||||
p, err := newPath("/api/v1/../.")
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
t.Log(p)
|
||||
}
|
||||
|
||||
const testStr = "/api/v1/foo/bar/baz/abc/def/fucc/b0y/of/a/local/dungeon/master/got/his/ass/fisted/for/free/and/he/was/absolutely/happy/with/that/"
|
||||
|
||||
func BenchmarkPatternPathParsing(b *testing.B) {
|
||||
for i := 0; i < b.N; i++ {
|
||||
p, err := newPath(testStr)
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
b.Log(len(p))
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkServePathParsing(b *testing.B) {
|
||||
for i := 0; i < b.N; i++ {
|
||||
p, err := newServePath(testStr)
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
b.Log(len(p))
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue
Block a user