lens @ 417041797674319485bea96cb38489670ff3b2ef

  1package sql
  2
  3import (
  4	"context"
  5	"errors"
  6
  7	"golang.org/x/crypto/bcrypt"
  8	"gorm.io/gorm"
  9
 10	"git.sr.ht/~gabrielgio/img/pkg/database/repository"
 11)
 12
 13type (
 14	User struct {
 15		gorm.Model
 16		Username string
 17		Name     string
 18		Password string
 19		IsAdmin  bool
 20		Path     string
 21	}
 22
 23	Users []*User
 24
 25	UserRepository struct {
 26		db *gorm.DB
 27	}
 28)
 29
 30var _ repository.UserRepository = &UserRepository{}
 31
 32var _ repository.AuthRepository = &UserRepository{}
 33
 34func NewUserRepository(db *gorm.DB) *UserRepository {
 35	return &UserRepository{
 36		db: db,
 37	}
 38}
 39
 40func (self *User) ToModel() *repository.User {
 41	return &repository.User{
 42		ID:       self.Model.ID,
 43		Name:     self.Name,
 44		Username: self.Username,
 45		Path:     self.Path,
 46		IsAdmin:  self.IsAdmin,
 47	}
 48}
 49
 50func (self Users) ToModel() (users []*repository.User) {
 51	for _, user := range self {
 52		users = append(users, user.ToModel())
 53	}
 54	return
 55}
 56
 57// Testing function, will remove later
 58// TODO: remove later
 59func (self *UserRepository) EnsureAdmin(ctx context.Context) {
 60	var exists bool
 61	self.db.
 62		WithContext(ctx).
 63		Model(&User{}).
 64		Select("count(*) > 0").
 65		Where("username = ?", "admin").
 66		Find(&exists)
 67
 68	if !exists {
 69		hash, _ := bcrypt.GenerateFromPassword([]byte("admin"), bcrypt.MinCost)
 70		self.db.Save(&User{
 71			Username: "admin",
 72			Path:     "/",
 73			IsAdmin:  true,
 74			Password: string(hash),
 75		})
 76	}
 77}
 78
 79func (self *UserRepository) List(ctx context.Context) ([]*repository.User, error) {
 80	users := Users{}
 81	result := self.db.
 82		WithContext(ctx).
 83		Find(&users)
 84
 85	if result.Error != nil {
 86		return nil, wrapError(result.Error)
 87	}
 88
 89	return users.ToModel(), nil
 90}
 91
 92func (self *UserRepository) Get(ctx context.Context, id uint) (*repository.User, error) {
 93	var user = &repository.User{ID: id}
 94	result := self.db.
 95		WithContext(ctx).
 96		First(user)
 97
 98	if result.Error != nil {
 99		return nil, wrapError(result.Error)
100	}
101
102	return user, nil
103}
104
105func (self *UserRepository) GetIDByUsername(ctx context.Context, username string) (uint, error) {
106	userID := struct {
107		ID uint
108	}{}
109
110	result := self.db.
111		WithContext(ctx).
112		Model(&User{}).
113		Where("username = ?", username).
114		First(&userID)
115
116	if result.Error != nil {
117		return 0, wrapError(result.Error)
118	}
119
120	return userID.ID, nil
121}
122
123func (self *UserRepository) GetPassword(ctx context.Context, id uint) ([]byte, error) {
124	userPassword := struct {
125		Password []byte
126	}{}
127
128	result := self.db.
129		WithContext(ctx).
130		Model(&User{}).
131		Where("id = ?", id).
132		First(&userPassword)
133
134	if result.Error != nil {
135		return nil, wrapError(result.Error)
136	}
137
138	return userPassword.Password, nil
139}
140
141func (self *UserRepository) Create(ctx context.Context, createUser *repository.CreateUser) (uint, error) {
142	user := &User{
143		Username: createUser.Username,
144		Name:     createUser.Name,
145		Path:     createUser.Path,
146		IsAdmin:  createUser.IsAdmin,
147		Password: string(createUser.Password),
148	}
149
150	result := self.db.
151		WithContext(ctx).
152		Create(user)
153	if result.Error != nil {
154		return 0, wrapError(result.Error)
155	}
156
157	return user.Model.ID, nil
158}
159
160func (self *UserRepository) Update(ctx context.Context, id uint, update *repository.UpdateUser) error {
161	result := self.db.
162		WithContext(ctx).
163		Model(&User{}).
164		Omit("password").
165		Where("id = ?", id).
166		Update("username", update.Username).
167		Update("name", update.Name).
168		Update("is_admin", update.IsAdmin).
169		Update("path", update.Path)
170
171	if result.Error != nil {
172		return wrapError(result.Error)
173	}
174
175	return nil
176}
177
178func (self *UserRepository) Delete(ctx context.Context, id uint) error {
179	user := &User{
180		Model: gorm.Model{
181			ID: id,
182		},
183	}
184
185	result := self.db.
186		WithContext(ctx).
187		Delete(user)
188	if result.Error != nil {
189		return wrapError(result.Error)
190	}
191	return nil
192}
193
194func (u *UserRepository) Any(ctx context.Context) (bool, error) {
195	var exists bool
196	result := u.db.
197		WithContext(ctx).
198		Model(&User{}).
199		Select("count(id) > 0").
200		Find(&exists)
201
202	if result.Error != nil {
203		return false, wrapError(result.Error)
204	}
205
206	return exists, nil
207}
208
209func (u *UserRepository) GetPathFromUserID(ctx context.Context, id uint) (string, error) {
210	var userPath string
211
212	result := u.db.
213		WithContext(ctx).
214		Model(&User{}).
215		Select("path").
216		Where("id = ?", id).
217		First(&userPath)
218
219	if result.Error != nil {
220		return "", wrapError(result.Error)
221	}
222
223	return userPath, nil
224}
225
226func (u *UserRepository) UpdatePassword(ctx context.Context, id uint, password []byte) error {
227	result := u.db.
228		WithContext(ctx).
229		Model(&User{}).
230		Where("id = ?", id).
231		Update("password", password)
232
233	return wrapError(result.Error)
234}
235
236func wrapError(err error) error {
237	if errors.Is(err, gorm.ErrRecordNotFound) {
238		return repository.ErrRecordNotFound
239	}
240	return err
241}