1package sql
2
3import (
4 "context"
5 "errors"
6
7 "golang.org/x/crypto/bcrypt"
8 "gorm.io/gorm"
9
10 "git.sr.ht/~gabrielgio/img/pkg/database/repository"
11)
12
13type (
14 User struct {
15 gorm.Model
16 Username string
17 Name string
18 Password string
19 IsAdmin bool
20 Path string
21 }
22
23 Users []*User
24
25 UserRepository struct {
26 db *gorm.DB
27 }
28)
29
30var _ repository.UserRepository = &UserRepository{}
31
32var _ repository.AuthRepository = &UserRepository{}
33
34func NewUserRepository(db *gorm.DB) *UserRepository {
35 return &UserRepository{
36 db: db,
37 }
38}
39
40func (self *User) ToModel() *repository.User {
41 return &repository.User{
42 ID: self.Model.ID,
43 Name: self.Name,
44 Username: self.Username,
45 Path: self.Path,
46 IsAdmin: self.IsAdmin,
47 }
48}
49
50func (self Users) ToModel() (users []*repository.User) {
51 for _, user := range self {
52 users = append(users, user.ToModel())
53 }
54 return
55}
56
// EnsureAdmin inserts a default user named "admin" (password "admin", path
// "/", IsAdmin true) when no user with that username exists yet.
//
// This is a testing helper that is meant to be removed.
// TODO: remove later.
//
// The method has no error return, so it is strictly best-effort: if the
// existence check or the password hashing fails, nothing is inserted. That
// avoids creating a duplicate admin after a failed lookup, or an admin whose
// stored hash is empty.
func (u *UserRepository) EnsureAdmin(ctx context.Context) {
	var exists bool
	result := u.db.
		WithContext(ctx).
		Model(&User{}).
		Select("count(*) > 0").
		Where("username = ?", "admin").
		Find(&exists)
	if result.Error != nil || exists {
		// On a query error we cannot know whether the admin exists, so a
		// blind insert could duplicate it; bail out instead.
		return
	}

	hash, err := bcrypt.GenerateFromPassword([]byte("admin"), bcrypt.MinCost)
	if err != nil {
		// Never persist an admin with an empty password hash.
		return
	}

	// Run the insert under the caller's context like every other query here.
	// Its error is dropped because this signature cannot report it.
	u.db.
		WithContext(ctx).
		Save(&User{
			Username: "admin",
			Path:     "/",
			IsAdmin:  true,
			Password: string(hash),
		})
}
78
// List returns all users that are not soft-deleted, converted to repository
// models. Database errors are passed through wrapError.
func (u *UserRepository) List(ctx context.Context) ([]*repository.User, error) {
	users := Users{}
	if err := u.db.WithContext(ctx).Find(&users).Error; err != nil {
		return nil, wrapError(err)
	}
	return users.ToModel(), nil
}
91
92func (self *UserRepository) Get(ctx context.Context, id uint) (*repository.User, error) {
93 var user = &repository.User{ID: id}
94 result := self.db.
95 WithContext(ctx).
96 First(user)
97
98 if result.Error != nil {
99 return nil, wrapError(result.Error)
100 }
101
102 return user, nil
103}
104
105func (self *UserRepository) GetIDByUsername(ctx context.Context, username string) (uint, error) {
106 userID := struct {
107 ID uint
108 }{}
109
110 result := self.db.
111 WithContext(ctx).
112 Model(&User{}).
113 Where("username = ?", username).
114 First(&userID)
115
116 if result.Error != nil {
117 return 0, wrapError(result.Error)
118 }
119
120 return userID.ID, nil
121}
122
123func (self *UserRepository) GetPassword(ctx context.Context, id uint) ([]byte, error) {
124 userPassword := struct {
125 Password []byte
126 }{}
127
128 result := self.db.
129 WithContext(ctx).
130 Model(&User{}).
131 Where("id = ?", id).
132 First(&userPassword)
133
134 if result.Error != nil {
135 return nil, wrapError(result.Error)
136 }
137
138 return userPassword.Password, nil
139}
140
// Create inserts a new user built from createUser and returns the ID the
// database assigned to it. The password bytes are stored as a string.
func (u *UserRepository) Create(ctx context.Context, createUser *repository.CreateUser) (uint, error) {
	record := &User{
		Username: createUser.Username,
		Name:     createUser.Name,
		Path:     createUser.Path,
		IsAdmin:  createUser.IsAdmin,
		Password: string(createUser.Password),
	}

	if err := u.db.WithContext(ctx).Create(record).Error; err != nil {
		return 0, wrapError(err)
	}
	return record.Model.ID, nil
}
159
// Update overwrites the username, name, is_admin and path columns of the user
// with the given ID. The password is never touched here (see UpdatePassword).
//
// All four columns are written in a single UPDATE statement, so the change is
// atomic. A map is used instead of a struct because GORM's struct-based
// Updates skips zero values, which would make it impossible to set IsAdmin to
// false or clear a string field.
func (u *UserRepository) Update(ctx context.Context, id uint, update *repository.UpdateUser) error {
	result := u.db.
		WithContext(ctx).
		Model(&User{}).
		Where("id = ?", id).
		Updates(map[string]interface{}{
			"username": update.Username,
			"name":     update.Name,
			"is_admin": update.IsAdmin,
			"path":     update.Path,
		})

	if result.Error != nil {
		return wrapError(result.Error)
	}

	return nil
}
177
// Delete removes the user with the given ID. Because User embeds gorm.Model,
// GORM performs a soft delete (it sets DeletedAt).
func (u *UserRepository) Delete(ctx context.Context, id uint) error {
	target := &User{Model: gorm.Model{ID: id}}
	// wrapError(nil) is nil, so success falls straight through.
	return wrapError(u.db.WithContext(ctx).Delete(target).Error)
}
193
// Any reports whether at least one (non-soft-deleted) user exists.
func (u *UserRepository) Any(ctx context.Context) (bool, error) {
	var exists bool
	err := u.db.
		WithContext(ctx).
		Model(&User{}).
		Select("count(id) > 0").
		Find(&exists).
		Error
	if err != nil {
		return false, wrapError(err)
	}
	return exists, nil
}
208
// GetPathFromUserID returns the path column of the user with the given ID.
// A missing user is reported as repository.ErrRecordNotFound.
func (u *UserRepository) GetPathFromUserID(ctx context.Context, id uint) (string, error) {
	var path string

	err := u.db.
		WithContext(ctx).
		Model(&User{}).
		Select("path").
		Where("id = ?", id).
		First(&path).
		Error
	if err != nil {
		return "", wrapError(err)
	}
	return path, nil
}
225
// UpdatePassword replaces the stored password hash of the user with the given
// ID.
//
// The hash is written as a string, matching how Create and EnsureAdmin store
// it (User.Password is a string column). Passing the raw []byte would bind a
// BLOB rather than TEXT value on some drivers, so the stored representation
// would depend on which code path wrote it. GetPassword reads it back into a
// []byte either way.
func (u *UserRepository) UpdatePassword(ctx context.Context, id uint, password []byte) error {
	result := u.db.
		WithContext(ctx).
		Model(&User{}).
		Where("id = ?", id).
		Update("password", string(password))

	// wrapError(nil) returns nil, so success needs no special case.
	return wrapError(result.Error)
}
235
// wrapError translates gorm.ErrRecordNotFound (including wrapped forms) into
// repository.ErrRecordNotFound. Any other error, and nil, is returned
// unchanged.
func wrapError(err error) error {
	switch {
	case errors.Is(err, gorm.ErrRecordNotFound):
		return repository.ErrRecordNotFound
	default:
		return err
	}
}