1package ext
2
3import (
4 "compress/gzip"
5 "compress/lzw"
6 "errors"
7 "io"
8 "log/slog"
9 "net/http"
10 "strconv"
11 "strings"
12
13 "git.gabrielgio.me/cerrado/pkg/u"
14 "github.com/andybalholm/brotli"
15 "github.com/klauspost/compress/zstd"
16)
17
var (
	// invalidParamErr is returned by getWeighedValue when an Accept-Encoding
	// parameter is not a usable "q=<weight>" parameter.
	invalidParamErr = errors.New("invalid weighted param")
)
21
// CompressionResponseWriter is the http.ResponseWriter handed to handlers by
// Compress: header and status calls go to the wrapped writer, while body bytes
// are routed through a compressor.
type CompressionResponseWriter struct {
	// innerWriter is the original ResponseWriter of the request.
	innerWriter http.ResponseWriter

	// compressWriter is the compressor that body writes are sent through; in
	// Compress it is built on top of innerWriter.
	compressWriter io.Writer
}
26
27func Compress(next http.HandlerFunc) http.HandlerFunc {
28 return func(w http.ResponseWriter, r *http.Request) {
29 if accept, ok := r.Header["Accept-Encoding"]; ok {
30 if compress, algo := GetCompressionWriter(u.FirstOrZero(accept), w); algo != "" {
31 defer compress.Close()
32 w.Header().Add("Content-Encoding", algo)
33 w = &CompressionResponseWriter{
34 innerWriter: w,
35 compressWriter: compress,
36 }
37 }
38 }
39 next(w, r)
40 }
41}
42
43func GetCompressionWriter(header string, inner io.Writer) (io.WriteCloser, string) {
44 c := GetCompression(header)
45 switch c {
46 case "br":
47 return GetBrotliWriter(inner), c
48 case "gzip":
49 return GetGZIPWriter(inner), c
50 case "compress":
51 return GetLZWWriter(inner), c
52 case "zstd":
53 return GetZSTDWriter(inner), c
54 default:
55 return nil, ""
56 }
57
58}
59
60func (c *CompressionResponseWriter) Header() http.Header {
61 return c.innerWriter.Header()
62}
63func (c *CompressionResponseWriter) Write(b []byte) (int, error) {
64 return c.compressWriter.Write(b)
65}
66
67func (c *CompressionResponseWriter) WriteHeader(statusCode int) {
68 c.innerWriter.WriteHeader(statusCode)
69}
70
71func GetCompression(header string) string {
72 c := "*"
73 q := 0.0
74
75 if header == "" {
76 return c
77 }
78
79 // https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Accept-Encoding
80 for _, e := range strings.Split(header, ",") {
81 ps := strings.Split(e, ";")
82 if len(ps) == 2 {
83 w, err := getWeighedValue(ps[1])
84 if err != nil {
85 slog.Error(
86 "Error parsing weighting from Accept-Encoding",
87 "error", err,
88 )
89 continue
90 }
91 // gettting weighting value
92 if w > q {
93 q = w
94 c = strings.Trim(ps[0], " ")
95 }
96 } else {
97 if 1 > q {
98 q = 1
99 c = strings.Trim(ps[0], " ")
100 }
101 }
102 }
103
104 return c
105}
106
// GetGZIPWriter returns a gzip compressor, set to gzip.BestCompression, that
// writes its output to w. The caller must Close it to flush the trailer.
func GetGZIPWriter(w io.Writer) io.WriteCloser {
	// NewWriterLevel only fails for an out-of-range level; BestCompression is
	// a valid constant, so the error is safely discarded.
	const level = gzip.BestCompression
	gz, _ := gzip.NewWriterLevel(w, level)
	return gz
}
113
114func GetBrotliWriter(w io.Writer) io.WriteCloser {
115 return brotli.NewWriterLevel(w, brotli.BestCompression)
116}
117
118func GetZSTDWriter(w io.Writer) io.WriteCloser {
119 // error can be ignored here since it will only opts are given
120 r, _ := zstd.NewWriter(w)
121 return r
122}
123
// GetLZWWriter returns an LZW compressor, using least-significant-bit-first
// ordering and 8-bit literals, that writes its output to w. The caller must
// Close it to flush pending data.
//
// NOTE(review): compress/lzw documents itself as the GIF/PDF flavour of LZW;
// verify it matches what "Content-Encoding: compress" clients expect before
// advertising it.
func GetLZWWriter(w io.Writer) io.WriteCloser {
	const literalWidth = 8
	return lzw.NewWriter(w, lzw.LSB, literalWidth)
}
127
128func getWeighedValue(part string) (float64, error) {
129 ps := strings.SplitN(part, "=", 2)
130 if len(ps) != 2 {
131 return 0, invalidParamErr
132 }
133 if name := strings.TrimSpace(ps[0]); name == "q" {
134 w, err := strconv.ParseFloat(ps[1], 64)
135 if err != nil {
136 return 0, err
137 }
138 return w, nil
139 }
140
141 return 0, invalidParamErr
142}