1package ext
2
3import (
4 "compress/gzip"
5 "compress/lzw"
6 "errors"
7 "io"
8 "log/slog"
9 "net/http"
10 "strconv"
11 "strings"
12
13 "git.gabrielgio.me/cerrado/pkg/u"
14 "github.com/andybalholm/brotli"
15 "github.com/klauspost/compress/zstd"
16)
17
// errInvalidParam is returned by getWeighedValue when an Accept-Encoding
// parameter is not a well-formed "q=<value>" weight.
var errInvalidParam = errors.New("Invalid weighted param")
21
// CompressionResponseWriter is an http.ResponseWriter that sends body bytes
// through a compressing writer while delegating headers and the status code to
// the wrapped ResponseWriter. Compress installs it around the next handler.
type CompressionResponseWriter struct {
	// innerWriter is the original ResponseWriter; header and status calls are
	// forwarded to it.
	innerWriter http.ResponseWriter
	// compressWriter receives the body bytes. As built by Compress, it
	// compresses into innerWriter.
	compressWriter io.Writer
}
26
27func Compress(next http.HandlerFunc) http.HandlerFunc {
28 return func(w http.ResponseWriter, r *http.Request) {
29
30 // TODO: hand this better
31 if strings.HasSuffix(r.URL.Path, ".tar.gz") {
32 next(w, r)
33 return
34 }
35
36 if accept, ok := r.Header["Accept-Encoding"]; ok {
37 if compress, algo := GetCompressionWriter(u.FirstOrZero(accept), w); algo != "" {
38 defer compress.Close()
39 w.Header().Add("Content-Encoding", algo)
40 w = &CompressionResponseWriter{
41 innerWriter: w,
42 compressWriter: compress,
43 }
44 }
45 }
46 next(w, r)
47 }
48}
49
50func GetCompressionWriter(header string, inner io.Writer) (io.WriteCloser, string) {
51 c := GetCompression(header)
52 switch c {
53 case "br":
54 return GetBrotliWriter(inner), c
55 case "gzip":
56 return GetGZIPWriter(inner), c
57 case "compress":
58 return GetLZWWriter(inner), c
59 case "zstd":
60 return GetZSTDWriter(inner), c
61 default:
62 return nil, ""
63 }
64
65}
66
67func (c *CompressionResponseWriter) Header() http.Header {
68 return c.innerWriter.Header()
69}
70func (c *CompressionResponseWriter) Write(b []byte) (int, error) {
71 return c.compressWriter.Write(b)
72}
73
74func (c *CompressionResponseWriter) WriteHeader(statusCode int) {
75 c.innerWriter.WriteHeader(statusCode)
76}
77
78func GetCompression(header string) string {
79 c := "*"
80 q := 0.0
81
82 if header == "" {
83 return c
84 }
85
86 // https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Accept-Encoding
87 for _, e := range strings.Split(header, ",") {
88 ps := strings.Split(e, ";")
89 if len(ps) == 2 {
90 w, err := getWeighedValue(ps[1])
91 if err != nil {
92 slog.Error(
93 "Error parsing weighting from Accept-Encoding",
94 "error", err,
95 )
96 continue
97 }
98 // gettting weighting value
99 if w > q {
100 q = w
101 c = strings.Trim(ps[0], " ")
102 }
103 } else {
104 if 1 > q {
105 q = 1
106 c = strings.Trim(ps[0], " ")
107 }
108 }
109 }
110
111 return c
112}
113
// GetGZIPWriter returns a gzip writer that compresses into w at
// gzip.BestCompression.
func GetGZIPWriter(w io.Writer) io.WriteCloser {
	// NewWriterLevel fails only for an out-of-range level, and BestCompression
	// is always in range, so the error is discarded.
	gz, _ := gzip.NewWriterLevel(w, gzip.BestCompression)
	return gz
}
120
121func GetBrotliWriter(w io.Writer) io.WriteCloser {
122 return brotli.NewWriterLevel(w, brotli.BestCompression)
123}
124
125func GetZSTDWriter(w io.Writer) io.WriteCloser {
126 // error can be ignored here since it will only opts are given
127 r, _ := zstd.NewWriter(w)
128 return r
129}
130
// GetLZWWriter returns an LZW writer over w that packs codes
// least-significant-bit first and uses 8-bit literals.
//
// NOTE: Go's compress/lzw implements the GIF/PDF flavour of LZW, not the Unix
// compress (.Z) format denoted by the HTTP "compress" content coding. Do not
// serve this output with "Content-Encoding: compress".
func GetLZWWriter(w io.Writer) io.WriteCloser {
	const literalWidth = 8 // bits per literal symbol
	return lzw.NewWriter(w, lzw.LSB, literalWidth)
}
134
135func getWeighedValue(part string) (float64, error) {
136 ps := strings.SplitN(part, "=", 2)
137 if len(ps) != 2 {
138 return 0, errInvalidParam
139 }
140 if name := strings.TrimSpace(ps[0]); name == "q" {
141 w, err := strconv.ParseFloat(ps[1], 64)
142 if err != nil {
143 return 0, err
144 }
145 return w, nil
146 }
147
148 return 0, errInvalidParam
149}