1package sql
2
3import (
4 "context"
5 "errors"
6
7 "golang.org/x/crypto/bcrypt"
8 "gorm.io/gorm"
9
10 "git.sr.ht/~gabrielgio/img/pkg/database/repository"
11)
12
13type (
14 User struct {
15 gorm.Model
16 Username string
17 Name string
18 Password string
19 IsAdmin bool
20 Path string
21 }
22
23 Users []*User
24
25 UserRepository struct {
26 db *gorm.DB
27 }
28)
29
30var _ repository.UserRepository = &UserRepository{}
31
32var _ repository.AuthRepository = &UserRepository{}
33
34func NewUserRepository(db *gorm.DB) *UserRepository {
35 return &UserRepository{
36 db: db,
37 }
38}
39
40func (self *User) ToModel() *repository.User {
41 return &repository.User{
42 ID: self.Model.ID,
43 Name: self.Name,
44 Username: self.Username,
45 Path: self.Path,
46 IsAdmin: self.IsAdmin,
47 }
48}
49
50func (self Users) ToModel() (users []*repository.User) {
51 for _, user := range self {
52 users = append(users, user.ToModel())
53 }
54 return
55}
56
57// Testing function, will remove later
58// TODO: remove later
59func (self *UserRepository) EnsureAdmin(ctx context.Context) {
60 var exists bool
61 self.db.
62 WithContext(ctx).
63 Model(&User{}).
64 Select("count(*) > 0").
65 Where("username = ?", "admin").
66 Find(&exists)
67
68 if !exists {
69 hash, _ := bcrypt.GenerateFromPassword([]byte("admin"), bcrypt.MinCost)
70 self.db.Save(&User{
71 Username: "admin",
72 Path: "/",
73 IsAdmin: true,
74 Password: string(hash),
75 })
76 }
77}
78
79func (self *UserRepository) List(ctx context.Context) ([]*repository.User, error) {
80 users := Users{}
81 result := self.db.
82 WithContext(ctx).
83 Find(&users)
84
85 if result.Error != nil {
86 return nil, wrapError(result.Error)
87 }
88
89 return users.ToModel(), nil
90}
91
92func (self *UserRepository) Get(ctx context.Context, id uint) (*repository.User, error) {
93 var user = &repository.User{ID: id}
94 result := self.db.
95 WithContext(ctx).
96 First(user)
97
98 if result.Error != nil {
99 return nil, wrapError(result.Error)
100 }
101
102 return user, nil
103}
104
105func (self *UserRepository) GetIDByUsername(ctx context.Context, username string) (uint, error) {
106 userID := struct {
107 ID uint
108 }{}
109
110 result := self.db.
111 WithContext(ctx).
112 Model(&User{}).
113 Where("username = ?", username).
114 First(&userID)
115
116 if result.Error != nil {
117 return 0, wrapError(result.Error)
118 }
119
120 return userID.ID, nil
121}
122
123func (self *UserRepository) GetPassword(ctx context.Context, id uint) ([]byte, error) {
124 userPassword := struct {
125 Password []byte
126 }{}
127
128 result := self.db.
129 WithContext(ctx).
130 Model(&User{}).
131 Where("id = ?", id).
132 First(&userPassword)
133
134 if result.Error != nil {
135 return nil, wrapError(result.Error)
136 }
137
138 return userPassword.Password, nil
139}
140
141func (self *UserRepository) Create(ctx context.Context, createUser *repository.CreateUser) (uint, error) {
142 user := &User{
143 Username: createUser.Username,
144 Name: createUser.Name,
145 Path: createUser.Path,
146 IsAdmin: createUser.IsAdmin,
147 Password: string(createUser.Password),
148 }
149
150 result := self.db.
151 WithContext(ctx).
152 Create(user)
153 if result.Error != nil {
154 return 0, wrapError(result.Error)
155 }
156
157 return user.Model.ID, nil
158}
159
160func (self *UserRepository) Update(ctx context.Context, id uint, update *repository.UpdateUser) error {
161 user := &User{
162 Model: gorm.Model{
163 ID: id,
164 },
165 Username: update.Username,
166 Name: update.Name,
167 IsAdmin: update.IsAdmin,
168 Path: update.Path,
169 }
170
171 result := self.db.
172 WithContext(ctx).
173 Omit("password").
174 Updates(user)
175 if result.Error != nil {
176 return wrapError(result.Error)
177 }
178
179 return nil
180}
181
182func (self *UserRepository) Delete(ctx context.Context, id uint) error {
183 user := &User{
184 Model: gorm.Model{
185 ID: id,
186 },
187 }
188
189 result := self.db.
190 WithContext(ctx).
191 Delete(user)
192 if result.Error != nil {
193 return wrapError(result.Error)
194 }
195 return nil
196}
197
198func (u *UserRepository) Any(ctx context.Context) (bool, error) {
199 var exists bool
200 result := u.db.
201 WithContext(ctx).
202 Model(&User{}).
203 Select("count(id) > 0").
204 Find(&exists)
205
206 if result.Error != nil {
207 return false, wrapError(result.Error)
208 }
209
210 return exists, nil
211}
212
213func (u *UserRepository) GetPathFromUserID(ctx context.Context, id uint) (string, error) {
214 var userPath string
215
216 result := u.db.
217 WithContext(ctx).
218 Model(&User{}).
219 Select("path").
220 Where("id = ?", id).
221 First(&userPath)
222
223 if result.Error != nil {
224 return "", wrapError(result.Error)
225 }
226
227 return userPath, nil
228}
229
230func (u *UserRepository) UpdatePassword(ctx context.Context, id uint, password []byte) error {
231 result := u.db.
232 WithContext(ctx).
233 Model(&User{}).
234 Where("id = ?", id).
235 Update("password", password)
236
237 return wrapError(result.Error)
238}
239
240func wrapError(err error) error {
241 if errors.Is(err, gorm.ErrRecordNotFound) {
242 return repository.ErrRecordNotFound
243 }
244 return err
245}