1package ext
2
import (
	"context"
	"encoding/base64"
	"errors"
	"net/http"
	"net/url"
	"time"

	"github.com/sirupsen/logrus"

	"git.sr.ht/~gabrielgio/img/pkg/database/repository"
	"git.sr.ht/~gabrielgio/img/pkg/service"
)
15
// HTML wraps next so that its responses carry a "text/html" Content-Type
// header. The header is set before next runs, so next may still override it.
func HTML(next http.HandlerFunc) http.HandlerFunc {
	withContentType := func(w http.ResponseWriter, r *http.Request) {
		header := w.Header()
		header.Set("Content-Type", "text/html")
		next(w, r)
	}
	return withContentType
}
22
23type (
24 User string
25
26 LogMiddleware struct {
27 entry *logrus.Entry
28 }
29)
30
31const (
32 UserKey User = "user"
33)
34
35func NewLogMiddleare(log *logrus.Entry) *LogMiddleware {
36 return &LogMiddleware{
37 entry: log,
38 }
39}
40
41func (l *LogMiddleware) HTTP(next http.HandlerFunc) http.HandlerFunc {
42 return func(w http.ResponseWriter, r *http.Request) {
43 start := time.Now()
44 next(w, r)
45 elapsed := time.Since(start)
46 l.entry.
47 WithField("time", elapsed).
48 WithField("path", r.URL.Path).
49 Info(r.Method)
50 }
51}
52
// AuthMiddleware holds what LoggedIn needs to authenticate requests via the
// "auth" cookie. Build it with NewAuthMiddleware.
type AuthMiddleware struct {
	// key is passed to service.ReadToken when reading the auth token.
	key []byte
	// entry is the logger; NewAuthMiddleware scopes it with context=auth.
	entry *logrus.Entry
	// userRepository loads the user referenced by a successfully read token.
	userRepository repository.UserRepository
}
58
59func NewAuthMiddleware(
60 key []byte,
61 log *logrus.Entry,
62 userRepository repository.UserRepository,
63) *AuthMiddleware {
64 return &AuthMiddleware{
65 key: key,
66 entry: log.WithField("context", "auth"),
67 userRepository: userRepository,
68 }
69}
70
71func (a *AuthMiddleware) LoggedIn(next http.HandlerFunc) http.HandlerFunc {
72 return func(w http.ResponseWriter, r *http.Request) {
73 path := r.URL.Path
74 if path == "/login" || path == "/initial" {
75 next(w, r)
76 return
77 }
78
79 redirectLogin := "/login?redirect=" + path
80 authBase64, err := r.Cookie("auth")
81 if errors.Is(err, http.ErrNoCookie) {
82 a.entry.Info("No auth provided")
83 http.Redirect(w, r, redirectLogin, http.StatusTemporaryRedirect)
84 return
85 }
86
87 auth, err := base64.StdEncoding.DecodeString(authBase64.Value)
88 if err != nil {
89 a.entry.Error(err)
90 return
91 }
92
93 token, err := service.ReadToken(auth, a.key)
94 if err != nil {
95 a.entry.Error(err)
96 http.Redirect(w, r, redirectLogin, http.StatusTemporaryRedirect)
97 return
98 }
99
100 user, err := a.userRepository.Get(r.Context(), token.UserID)
101 if err != nil {
102 a.entry.Error(err)
103 return
104 }
105
106 r = r.WithContext(context.WithValue(r.Context(), UserKey, user))
107 a.entry.
108 WithField("userID", token.UserID).
109 WithField("username", token.Username).
110 Info("user recognized")
111 next(w, r)
112 }
113}
114
115func GetUserFromCtx(r *http.Request) *repository.User {
116 tokenValue := r.Context().Value(UserKey)
117 if token, ok := tokenValue.(*repository.User); ok {
118 return token
119 }
120 return nil
121}
122
// InitialSetupMiddleware redirects to "/initial" until at least one user
// exists (see Check). Build it with NewInitialSetupMiddleware.
type InitialSetupMiddleware struct {
	// userRepository is queried with Any to see whether any user exists yet.
	userRepository repository.UserRepository
}
126
127func NewInitialSetupMiddleware(userRepository repository.UserRepository) *InitialSetupMiddleware {
128 return &InitialSetupMiddleware{
129 userRepository: userRepository,
130 }
131}
132
133func (i *InitialSetupMiddleware) Check(next http.HandlerFunc) http.HandlerFunc {
134 return func(w http.ResponseWriter, r *http.Request) {
135
136 // if user has been set to context it is logged in already
137 token := GetUserFromCtx(r)
138 if token != nil {
139 next(w, r)
140 return
141 }
142
143 path := r.URL.Path
144 if path == "/initial" {
145 next(w, r)
146 return
147 }
148
149 exists, err := i.userRepository.Any(r.Context())
150 if err != nil {
151 InternalServerError(w, err)
152 return
153 }
154
155 if exists {
156 next(w, r)
157 return
158 }
159 http.Redirect(w, r, "/initial", http.StatusTemporaryRedirect)
160 }
161}