diff options
| -rw-r--r-- | go.mod | 3 | ||||
| -rw-r--r-- | handler.go | 53 | ||||
| -rw-r--r-- | handler_test.go | 75 | 
3 files changed, 131 insertions, 0 deletions
@@ -0,0 +1,3 @@ +module gziphandler + +go 1.15 diff --git a/handler.go b/handler.go new file mode 100644 index 0000000..ab21180 --- /dev/null +++ b/handler.go @@ -0,0 +1,53 @@ +package gziphandler + +import ( +	"compress/gzip" +	"net/http" +	"strings" +) + +type gw struct { +	http.ResponseWriter +	w *gzip.Writer +} + +func (w *gw) Write(d []byte) (int, error) { +	return w.w.Write(d) +} + +func (w *gw) Close() error { +	return w.w.Close() +} + +func acceptsGzip(r *http.Request) bool { +	ae := r.Header.Get("Accept-Encoding") +	for _, e := range strings.Split(ae, ",") { +		vals := strings.Split(e, ";") +		if len(vals) < 1 { +			continue +		} +		if strings.TrimSpace(vals[0]) == "gzip" { +			return true +		} +	} +	return false +} + +// Handler returns an http.Handler that compresses the response data written +// by an existing handler h, using the compress/gzip.Writer with default +// compression level. +func Handler(h http.Handler) http.Handler { +	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { +		if !acceptsGzip(r) { +			h.ServeHTTP(w, r) +			return +		} +		w.Header().Set("Content-Encoding", "gzip") +		gw := &gw{ +			ResponseWriter: w, +			w:              gzip.NewWriter(w), +		} +		h.ServeHTTP(gw, r) +		gw.Close() +	}) +} diff --git a/handler_test.go b/handler_test.go new file mode 100644 index 0000000..f1a034e --- /dev/null +++ b/handler_test.go @@ -0,0 +1,75 @@ +package gziphandler + +import ( +	"compress/gzip" +	"io" +	"net" +	"net/http" +	"strings" +	"testing" +) + +type d struct { +	data string +	gzip bool +} + +func TestHandler(t *testing.T) { +	data := []d{ +		{"abc", true}, +		{"qweqweqwe", false}, +		{"lkajwdlajwdlkajwdlkjawl", true}, +	} +	http.HandleFunc("/test", func(w http.ResponseWriter, r *http.Request) { +		_, err := io.Copy(w, r.Body) +		if err != nil { +			t.Fatal(err) +		} + +	}) +	handler := Handler(http.DefaultServeMux) +	l, err := net.Listen("tcp", "127.0.0.1:8080") +	if err != nil { +		t.Fatal(err) +	} +	defer l.Close() +	go http.Serve(l, handler) +	for _, d := range data { +		doReq(t, d) +	} +} + +func doReq(t *testing.T, d d) { +	req, err := http.NewRequest(http.MethodPost, "http://127.0.0.1:8080/test", strings.NewReader(d.data)) +	if err != nil { +		t.Fatal(err) +	} +	if d.gzip { +		req.Header.Set("Accept-Encoding", "gzip") +	} +	resp, err := http.DefaultClient.Do(req) +	if err != nil { +		t.Fatal(err) +	} +	defer resp.Body.Close() +	var sb strings.Builder +	r := resp.Body +	if d.gzip { +		ce := resp.Header.Get("Content-Encoding") +		if ce != "gzip" { +			t.Errorf("expected content encoding %s, got %s", "gzip", ce) +		} +		r, err = gzip.NewReader(resp.Body) +		if err != nil { +			t.Fatal(err) +		} +	} +	_, err = io.Copy(&sb, r) +	if err != nil { +		t.Fatal(err) +	} +	s := sb.String() +	if d.data != s { +		t.Errorf("expected response %s, got %s", d.data, s) +	} +}  | 
