cerrado @ 72495f4538215051540eb05c14db0ed16142e06e

  1package ext
  2
  3import (
  4	"compress/gzip"
  5	"compress/lzw"
  6	"errors"
  7	"io"
  8	"log/slog"
  9	"net/http"
 10	"strconv"
 11	"strings"
 12
 13	"git.gabrielgio.me/cerrado/pkg/u"
 14	"github.com/andybalholm/brotli"
 15	"github.com/klauspost/compress/zstd"
 16)
 17
 18var (
 19	invalidParamErr = errors.New("Invalid weighted param")
 20)
 21
 22type CompressionResponseWriter struct {
 23	innerWriter    http.ResponseWriter
 24	compressWriter io.Writer
 25}
 26
 27func Compress(next http.HandlerFunc) http.HandlerFunc {
 28	return func(w http.ResponseWriter, r *http.Request) {
 29		if accept, ok := r.Header["Accept-Encoding"]; ok {
 30			if compress, algo := GetCompressionWriter(u.FirstOrZero(accept), w); algo != "" {
 31				defer compress.Close()
 32				w.Header().Add("Content-Encoding", algo)
 33				w = &CompressionResponseWriter{
 34					innerWriter:    w,
 35					compressWriter: compress,
 36				}
 37			}
 38		}
 39		next(w, r)
 40	}
 41}
 42
 43func GetCompressionWriter(header string, inner io.Writer) (io.WriteCloser, string) {
 44	c := GetCompression(header)
 45	switch c {
 46	case "br":
 47		return GetBrotliWriter(inner), c
 48	case "gzip":
 49		return GetGZIPWriter(inner), c
 50	case "compress":
 51		return GetLZWWriter(inner), c
 52	case "zstd":
 53		return GetZSTDWriter(inner), c
 54	default:
 55		return nil, ""
 56	}
 57
 58}
 59
 60func (c *CompressionResponseWriter) Header() http.Header {
 61	return c.innerWriter.Header()
 62}
 63func (c *CompressionResponseWriter) Write(b []byte) (int, error) {
 64	return c.compressWriter.Write(b)
 65}
 66
 67func (c *CompressionResponseWriter) WriteHeader(statusCode int) {
 68	c.innerWriter.WriteHeader(statusCode)
 69}
 70
 71func GetCompression(header string) string {
 72	c := "*"
 73	q := 0.0
 74
 75	if header == "" {
 76		return c
 77	}
 78
 79	// https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Accept-Encoding
 80	for _, e := range strings.Split(header, ",") {
 81		ps := strings.Split(e, ";")
 82		if len(ps) == 2 {
 83			w, err := getWeighedValue(ps[1])
 84			if err != nil {
 85				slog.Error(
 86					"Error parsing weighting from Accept-Encoding",
 87					"error", err,
 88				)
 89				continue
 90			}
 91			// gettting weighting value
 92			if w > q {
 93				q = w
 94				c = strings.Trim(ps[0], " ")
 95			}
 96		} else {
 97			if 1 > q {
 98				q = 1
 99				c = strings.Trim(ps[0], " ")
100			}
101		}
102	}
103
104	return c
105}
106
// GetGZIPWriter returns a gzip writer using gzip.BestCompression that writes
// its compressed output to w.
func GetGZIPWriter(w io.Writer) io.WriteCloser {
	// NewWriterLevel only rejects levels outside the valid range, and
	// BestCompression is always valid, so the error is safe to discard.
	zw, _ := gzip.NewWriterLevel(w, gzip.BestCompression)
	return zw
}
113
114func GetBrotliWriter(w io.Writer) io.WriteCloser {
115	return brotli.NewWriterLevel(w, brotli.BestCompression)
116}
117
118func GetZSTDWriter(w io.Writer) io.WriteCloser {
119	// error can be ignored here since it will only opts are given
120	r, _ := zstd.NewWriter(w)
121	return r
122}
123
// GetLZWWriter returns an LZW writer (LSB bit order, 8-bit literals) that
// writes its compressed output to w.
//
// Note: compress/lzw implements the GIF/PDF flavour of LZW and writes no .Z
// header, so its output is not the UNIX compress format that the HTTP
// "compress" content-coding refers to.
func GetLZWWriter(w io.Writer) io.WriteCloser {
	const literalWidth = 8
	return lzw.NewWriter(w, lzw.LSB, literalWidth)
}
127
128func getWeighedValue(part string) (float64, error) {
129	ps := strings.SplitN(part, "=", 2)
130	if len(ps) != 2 {
131		return 0, invalidParamErr
132	}
133	if name := strings.TrimSpace(ps[0]); name == "q" {
134		w, err := strconv.ParseFloat(ps[1], 64)
135		if err != nil {
136			return 0, err
137		}
138		return w, nil
139	}
140
141	return 0, invalidParamErr
142}