summaryrefslogtreecommitdiff
path: root/internal/auth/auth_test.go
diff options
context:
space:
mode:
Diffstat (limited to 'internal/auth/auth_test.go')
-rw-r--r--internal/auth/auth_test.go196
1 files changed, 196 insertions, 0 deletions
diff --git a/internal/auth/auth_test.go b/internal/auth/auth_test.go
index 013a4aa..fbe582b 100644
--- a/internal/auth/auth_test.go
+++ b/internal/auth/auth_test.go
@@ -1,6 +1,7 @@
package auth
import (
+ "context"
"database/sql"
"errors"
"testing"
@@ -106,3 +107,198 @@ func TestCreateUser(t *testing.T) {
t.Errorf("there were unfulfilled expectations: %s", err)
}
}
+
+func TestUserCount(t *testing.T) {
+ db, mock, err := sqlmock.New()
+ if err != nil {
+ t.Fatalf("an error '%s' was not expected when opening a stub database connection", err)
+ }
+ defer func() { _ = db.Close() }()
+
+ service := NewService(db)
+
+ rows := sqlmock.NewRows([]string{"count"}).AddRow(5)
+ mock.ExpectQuery("SELECT COUNT").WillReturnRows(rows)
+
+ count, err := service.UserCount()
+ if err != nil {
+ t.Errorf("expected no error, got %v", err)
+ }
+ if count != 5 {
+ t.Errorf("expected count 5, got %d", count)
+ }
+}
+
+func TestEnsureDefaultUser_NoUsers(t *testing.T) {
+ db, mock, err := sqlmock.New()
+ if err != nil {
+ t.Fatalf("an error '%s' was not expected when opening a stub database connection", err)
+ }
+ defer func() { _ = db.Close() }()
+
+ service := NewService(db)
+
+ // Expect count query (no users)
+ countRows := sqlmock.NewRows([]string{"count"}).AddRow(0)
+ mock.ExpectQuery("SELECT COUNT").WillReturnRows(countRows)
+
+ // Expect check if user exists
+ mock.ExpectQuery("SELECT id, username, password_hash, created_at FROM users WHERE username = ?").
+ WithArgs("admin").
+ WillReturnError(sql.ErrNoRows)
+
+ // Expect insert
+ mock.ExpectExec("INSERT INTO users").
+ WithArgs("admin", sqlmock.AnyArg()).
+ WillReturnResult(sqlmock.NewResult(1, 1))
+
+ // Expect retrieve created user
+ rows := sqlmock.NewRows([]string{"id", "username", "password_hash", "created_at"}).
+ AddRow(1, "admin", "hashedpassword", time.Now())
+ mock.ExpectQuery("SELECT id, username, password_hash, created_at FROM users WHERE id = ?").
+ WithArgs(1).
+ WillReturnRows(rows)
+
+ err = service.EnsureDefaultUser("admin", "password")
+ if err != nil {
+ t.Errorf("expected no error, got %v", err)
+ }
+
+ if err := mock.ExpectationsWereMet(); err != nil {
+ t.Errorf("there were unfulfilled expectations: %s", err)
+ }
+}
+
+func TestEnsureDefaultUser_UsersExist(t *testing.T) {
+ db, mock, err := sqlmock.New()
+ if err != nil {
+ t.Fatalf("an error '%s' was not expected when opening a stub database connection", err)
+ }
+ defer func() { _ = db.Close() }()
+
+ service := NewService(db)
+
+ // Expect count query (users exist)
+ countRows := sqlmock.NewRows([]string{"count"}).AddRow(1)
+ mock.ExpectQuery("SELECT COUNT").WillReturnRows(countRows)
+
+ // Should not create user when users exist
+ err = service.EnsureDefaultUser("admin", "password")
+ if err != nil {
+ t.Errorf("expected no error, got %v", err)
+ }
+
+ if err := mock.ExpectationsWereMet(); err != nil {
+ t.Errorf("there were unfulfilled expectations: %s", err)
+ }
+}
+
+func TestAuthenticate_WrongPassword(t *testing.T) {
+ db, mock, err := sqlmock.New()
+ if err != nil {
+ t.Fatalf("an error '%s' was not expected when opening a stub database connection", err)
+ }
+ defer func() { _ = db.Close() }()
+
+ service := NewService(db)
+
+ // Hash a different password
+ hash, _ := bcrypt.GenerateFromPassword([]byte("correctpassword"), bcrypt.DefaultCost)
+
+ rows := sqlmock.NewRows([]string{"id", "username", "password_hash", "created_at"}).
+ AddRow(1, "testuser", string(hash), time.Now())
+
+ mock.ExpectQuery("SELECT id, username, password_hash, created_at FROM users WHERE username = ?").
+ WithArgs("testuser").
+ WillReturnRows(rows)
+
+ // Try with wrong password
+ _, err = service.Authenticate("testuser", "wrongpassword")
+ if !errors.Is(err, ErrInvalidCredentials) {
+ t.Errorf("expected ErrInvalidCredentials, got %v", err)
+ }
+}
+
+func TestCreateUser_AlreadyExists(t *testing.T) {
+ db, mock, err := sqlmock.New()
+ if err != nil {
+ t.Fatalf("an error '%s' was not expected when opening a stub database connection", err)
+ }
+ defer func() { _ = db.Close() }()
+
+ service := NewService(db)
+
+ // User already exists
+ hash, _ := bcrypt.GenerateFromPassword([]byte("password"), bcrypt.DefaultCost)
+ rows := sqlmock.NewRows([]string{"id", "username", "password_hash", "created_at"}).
+ AddRow(1, "existinguser", string(hash), time.Now())
+
+ mock.ExpectQuery("SELECT id, username, password_hash, created_at FROM users WHERE username = ?").
+ WithArgs("existinguser").
+ WillReturnRows(rows)
+
+ _, err = service.CreateUser("existinguser", "password")
+ if !errors.Is(err, ErrUserExists) {
+ t.Errorf("expected ErrUserExists, got %v", err)
+ }
+}
+
+func TestGetUserByID(t *testing.T) {
+ db, mock, err := sqlmock.New()
+ if err != nil {
+ t.Fatalf("an error '%s' was not expected when opening a stub database connection", err)
+ }
+ defer func() { _ = db.Close() }()
+
+ service := NewService(db)
+
+ rows := sqlmock.NewRows([]string{"id", "username", "password_hash", "created_at"}).
+ AddRow(42, "testuser", "hashedpassword", time.Now())
+
+ mock.ExpectQuery("SELECT id, username, password_hash, created_at FROM users WHERE id = ?").
+ WithArgs(int64(42)).
+ WillReturnRows(rows)
+
+ user, err := service.GetUserByID(42)
+ if err != nil {
+ t.Errorf("expected no error, got %v", err)
+ }
+ if user.ID != 42 {
+ t.Errorf("expected user ID 42, got %d", user.ID)
+ }
+}
+
+func TestGetUserByID_NotFound(t *testing.T) {
+ db, mock, err := sqlmock.New()
+ if err != nil {
+ t.Fatalf("an error '%s' was not expected when opening a stub database connection", err)
+ }
+ defer func() { _ = db.Close() }()
+
+ service := NewService(db)
+
+ mock.ExpectQuery("SELECT id, username, password_hash, created_at FROM users WHERE id = ?").
+ WithArgs(int64(999)).
+ WillReturnError(sql.ErrNoRows)
+
+ _, err = service.GetUserByID(999)
+ if err == nil {
+ t.Error("expected error for non-existent user")
+ }
+}
+
+func TestGetCSRFTokenFromContext(t *testing.T) {
+ // Test with token in context
+ ctx := context.WithValue(context.Background(), ContextKeyCSRF, "test-token")
+ token := GetCSRFTokenFromContext(ctx)
+ if token != "test-token" {
+ t.Errorf("expected 'test-token', got '%s'", token)
+ }
+
+ // Test without token in context
+ emptyCtx := context.Background()
+ emptyToken := GetCSRFTokenFromContext(emptyCtx)
+ if emptyToken != "" {
+ t.Errorf("expected empty string, got '%s'", emptyToken)
+ }
+}