cerrado @ ba84c0d82066739adbca468846a2688e02432b6f

  1package ext
  2
  3import (
  4	"compress/gzip"
  5	"compress/lzw"
  6	"errors"
  7	"io"
  8	"log/slog"
  9	"net/http"
 10	"strconv"
 11	"strings"
 12
 13	"git.gabrielgio.me/cerrado/pkg/u"
 14	"github.com/andybalholm/brotli"
 15	"github.com/klauspost/compress/zstd"
 16)
 17
 18var (
 19	invalidParamErr = errors.New("Invalid weighted param")
 20)
 21
 22type CompressionResponseWriter struct {
 23	innerWriter    http.ResponseWriter
 24	compressWriter io.Writer
 25}
 26
 27func Compress(next http.HandlerFunc) http.HandlerFunc {
 28	return func(w http.ResponseWriter, r *http.Request) {
 29
 30		// TODO: hand this better
 31		if strings.HasSuffix(r.URL.Path, ".tar.gz") {
 32			next(w, r)
 33			return
 34		}
 35
 36		if accept, ok := r.Header["Accept-Encoding"]; ok {
 37			if compress, algo := GetCompressionWriter(u.FirstOrZero(accept), w); algo != "" {
 38				defer compress.Close()
 39				w.Header().Add("Content-Encoding", algo)
 40				w = &CompressionResponseWriter{
 41					innerWriter:    w,
 42					compressWriter: compress,
 43				}
 44			}
 45		}
 46		next(w, r)
 47	}
 48}
 49
 50func GetCompressionWriter(header string, inner io.Writer) (io.WriteCloser, string) {
 51	c := GetCompression(header)
 52	switch c {
 53	case "br":
 54		return GetBrotliWriter(inner), c
 55	case "gzip":
 56		return GetGZIPWriter(inner), c
 57	case "compress":
 58		return GetLZWWriter(inner), c
 59	case "zstd":
 60		return GetZSTDWriter(inner), c
 61	default:
 62		return nil, ""
 63	}
 64
 65}
 66
 67func (c *CompressionResponseWriter) Header() http.Header {
 68	return c.innerWriter.Header()
 69}
 70func (c *CompressionResponseWriter) Write(b []byte) (int, error) {
 71	return c.compressWriter.Write(b)
 72}
 73
 74func (c *CompressionResponseWriter) WriteHeader(statusCode int) {
 75	c.innerWriter.WriteHeader(statusCode)
 76}
 77
 78func GetCompression(header string) string {
 79	c := "*"
 80	q := 0.0
 81
 82	if header == "" {
 83		return c
 84	}
 85
 86	// https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Accept-Encoding
 87	for _, e := range strings.Split(header, ",") {
 88		ps := strings.Split(e, ";")
 89		if len(ps) == 2 {
 90			w, err := getWeighedValue(ps[1])
 91			if err != nil {
 92				slog.Error(
 93					"Error parsing weighting from Accept-Encoding",
 94					"error", err,
 95				)
 96				continue
 97			}
 98			// gettting weighting value
 99			if w > q {
100				q = w
101				c = strings.Trim(ps[0], " ")
102			}
103		} else {
104			if 1 > q {
105				q = 1
106				c = strings.Trim(ps[0], " ")
107			}
108		}
109	}
110
111	return c
112}
113
// GetGZIPWriter returns a gzip writer that compresses into w at
// gzip.BestCompression.
func GetGZIPWriter(w io.Writer) io.WriteCloser {
	// NewWriterLevel only fails for an out-of-range level, and
	// gzip.BestCompression is always in range, so the error is discarded.
	gz, _ := gzip.NewWriterLevel(w, gzip.BestCompression)
	return gz
}
120
121func GetBrotliWriter(w io.Writer) io.WriteCloser {
122	return brotli.NewWriterLevel(w, brotli.BestCompression)
123}
124
125func GetZSTDWriter(w io.Writer) io.WriteCloser {
126	// error can be ignored here since it will only opts are given
127	r, _ := zstd.NewWriter(w)
128	return r
129}
130
// GetLZWWriter returns a compress/lzw writer over w that uses LSB-first bit
// order and 8-bit literals.
//
// NOTE(review): compress/lzw implements the GIF/PDF variant of LZW, not the
// Unix compress (.Z) format named by the HTTP "compress" content coding. This
// output is therefore presumably not decodable as "Content-Encoding: compress".
func GetLZWWriter(w io.Writer) io.WriteCloser {
	const litWidth = 8
	return lzw.NewWriter(w, lzw.LSB, litWidth)
}
134
135func getWeighedValue(part string) (float64, error) {
136	ps := strings.SplitN(part, "=", 2)
137	if len(ps) != 2 {
138		return 0, invalidParamErr
139	}
140	if name := strings.TrimSpace(ps[0]); name == "q" {
141		w, err := strconv.ParseFloat(ps[1], 64)
142		if err != nil {
143			return 0, err
144		}
145		return w, nil
146	}
147
148	return 0, invalidParamErr
149}