package auth import ( "context" "database/sql" "errors" "testing" "time" "github.com/DATA-DOG/go-sqlmock" "golang.org/x/crypto/bcrypt" ) func TestAuthenticate(t *testing.T) { db, mock, err := sqlmock.New() if err != nil { t.Fatalf("an error '%s' was not expected when opening a stub database connection", err) } defer func() { _ = db.Close() }() service := NewService(db) password := "secret" hash, _ := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost) rows := sqlmock.NewRows([]string{"id", "username", "password_hash", "created_at"}). AddRow(1, "testuser", string(hash), time.Now()) mock.ExpectQuery("SELECT id, username, password_hash, created_at FROM users WHERE username = ?"). WithArgs("testuser"). WillReturnRows(rows) user, err := service.Authenticate("testuser", password) if err != nil { t.Errorf("expected no error, got %v", err) } if user == nil { t.Fatal("expected user, got nil") } if user.Username != "testuser" { t.Errorf("expected username testuser, got %s", user.Username) } if err := mock.ExpectationsWereMet(); err != nil { t.Errorf("there were unfulfilled expectations: %s", err) } } func TestAuthenticate_InvalidCredentials(t *testing.T) { db, mock, err := sqlmock.New() if err != nil { t.Fatalf("an error '%s' was not expected when opening a stub database connection", err) } defer func() { _ = db.Close() }() service := NewService(db) mock.ExpectQuery("SELECT id, username, password_hash, created_at FROM users WHERE username = ?"). WithArgs("nonexistent"). WillReturnError(sql.ErrNoRows) _, err = service.Authenticate("nonexistent", "password") if !errors.Is(err, ErrInvalidCredentials) { t.Errorf("expected ErrInvalidCredentials, got %v", err) } if err := mock.ExpectationsWereMet(); err != nil { t.Errorf("there were unfulfilled expectations: %s", err) } } func TestCreateUser(t *testing.T) { db, mock, err := sqlmock.New() if err != nil { t.Fatalf("an error '%s' was not expected when opening a stub database connection", err) } defer func() { _ = db.Close() }() service := NewService(db) // Expect check if user exists mock.ExpectQuery("SELECT id, username, password_hash, created_at FROM users WHERE username = ?"). WithArgs("newuser"). WillReturnError(sql.ErrNoRows) // Expect insert mock.ExpectExec("INSERT INTO users"). WithArgs("newuser", sqlmock.AnyArg()). WillReturnResult(sqlmock.NewResult(1, 1)) // Expect retrieve created user rows := sqlmock.NewRows([]string{"id", "username", "password_hash", "created_at"}). AddRow(1, "newuser", "hashedpassword", time.Now()) mock.ExpectQuery("SELECT id, username, password_hash, created_at FROM users WHERE id = ?"). WithArgs(1). WillReturnRows(rows) user, err := service.CreateUser("newuser", "password") if err != nil { t.Errorf("expected no error, got %v", err) } if user.Username != "newuser" { t.Errorf("expected username newuser, got %s", user.Username) } if err := mock.ExpectationsWereMet(); err != nil { t.Errorf("there were unfulfilled expectations: %s", err) } } func TestUserCount(t *testing.T) { db, mock, err := sqlmock.New() if err != nil { t.Fatalf("an error '%s' was not expected when opening a stub database connection", err) } defer func() { _ = db.Close() }() service := NewService(db) rows := sqlmock.NewRows([]string{"count"}).AddRow(5) mock.ExpectQuery("SELECT COUNT").WillReturnRows(rows) count, err := service.UserCount() if err != nil { t.Errorf("expected no error, got %v", err) } if count != 5 { t.Errorf("expected count 5, got %d", count) } } func TestEnsureDefaultUser_NoUsers(t *testing.T) { db, mock, err := sqlmock.New() if err != nil { t.Fatalf("an error '%s' was not expected when opening a stub database connection", err) } defer func() { _ = db.Close() }() service := NewService(db) // Expect count query (no users) countRows := sqlmock.NewRows([]string{"count"}).AddRow(0) mock.ExpectQuery("SELECT COUNT").WillReturnRows(countRows) // Expect check if user exists mock.ExpectQuery("SELECT id, username, password_hash, created_at FROM users WHERE username = ?"). WithArgs("admin"). WillReturnError(sql.ErrNoRows) // Expect insert mock.ExpectExec("INSERT INTO users"). WithArgs("admin", sqlmock.AnyArg()). WillReturnResult(sqlmock.NewResult(1, 1)) // Expect retrieve created user rows := sqlmock.NewRows([]string{"id", "username", "password_hash", "created_at"}). AddRow(1, "admin", "hashedpassword", time.Now()) mock.ExpectQuery("SELECT id, username, password_hash, created_at FROM users WHERE id = ?"). WithArgs(1). WillReturnRows(rows) err = service.EnsureDefaultUser("admin", "password") if err != nil { t.Errorf("expected no error, got %v", err) } if err := mock.ExpectationsWereMet(); err != nil { t.Errorf("there were unfulfilled expectations: %s", err) } } func TestEnsureDefaultUser_UsersExist(t *testing.T) { db, mock, err := sqlmock.New() if err != nil { t.Fatalf("an error '%s' was not expected when opening a stub database connection", err) } defer func() { _ = db.Close() }() service := NewService(db) // Expect count query (users exist) countRows := sqlmock.NewRows([]string{"count"}).AddRow(1) mock.ExpectQuery("SELECT COUNT").WillReturnRows(countRows) // Should not create user when users exist err = service.EnsureDefaultUser("admin", "password") if err != nil { t.Errorf("expected no error, got %v", err) } if err := mock.ExpectationsWereMet(); err != nil { t.Errorf("there were unfulfilled expectations: %s", err) } } func TestAuthenticate_WrongPassword(t *testing.T) { db, mock, err := sqlmock.New() if err != nil { t.Fatalf("an error '%s' was not expected when opening a stub database connection", err) } defer func() { _ = db.Close() }() service := NewService(db) // Hash a different password hash, _ := bcrypt.GenerateFromPassword([]byte("correctpassword"), bcrypt.DefaultCost) rows := sqlmock.NewRows([]string{"id", "username", "password_hash", "created_at"}). AddRow(1, "testuser", string(hash), time.Now()) mock.ExpectQuery("SELECT id, username, password_hash, created_at FROM users WHERE username = ?"). WithArgs("testuser"). WillReturnRows(rows) // Try with wrong password _, err = service.Authenticate("testuser", "wrongpassword") if !errors.Is(err, ErrInvalidCredentials) { t.Errorf("expected ErrInvalidCredentials, got %v", err) } } func TestCreateUser_AlreadyExists(t *testing.T) { db, mock, err := sqlmock.New() if err != nil { t.Fatalf("an error '%s' was not expected when opening a stub database connection", err) } defer func() { _ = db.Close() }() service := NewService(db) // User already exists hash, _ := bcrypt.GenerateFromPassword([]byte("password"), bcrypt.DefaultCost) rows := sqlmock.NewRows([]string{"id", "username", "password_hash", "created_at"}). AddRow(1, "existinguser", string(hash), time.Now()) mock.ExpectQuery("SELECT id, username, password_hash, created_at FROM users WHERE username = ?"). WithArgs("existinguser"). WillReturnRows(rows) _, err = service.CreateUser("existinguser", "password") if !errors.Is(err, ErrUserExists) { t.Errorf("expected ErrUserExists, got %v", err) } } func TestGetUserByID(t *testing.T) { db, mock, err := sqlmock.New() if err != nil { t.Fatalf("an error '%s' was not expected when opening a stub database connection", err) } defer func() { _ = db.Close() }() service := NewService(db) rows := sqlmock.NewRows([]string{"id", "username", "password_hash", "created_at"}). AddRow(42, "testuser", "hashedpassword", time.Now()) mock.ExpectQuery("SELECT id, username, password_hash, created_at FROM users WHERE id = ?"). WithArgs(int64(42)). WillReturnRows(rows) user, err := service.GetUserByID(42) if err != nil { t.Errorf("expected no error, got %v", err) } if user.ID != 42 { t.Errorf("expected user ID 42, got %d", user.ID) } } func TestGetUserByID_NotFound(t *testing.T) { db, mock, err := sqlmock.New() if err != nil { t.Fatalf("an error '%s' was not expected when opening a stub database connection", err) } defer func() { _ = db.Close() }() service := NewService(db) mock.ExpectQuery("SELECT id, username, password_hash, created_at FROM users WHERE id = ?"). WithArgs(int64(999)). WillReturnError(sql.ErrNoRows) _, err = service.GetUserByID(999) if err == nil { t.Error("expected error for non-existent user") } } func TestGetCSRFTokenFromContext(t *testing.T) { // Test with token in context ctx := context.WithValue(context.Background(), ContextKeyCSRF, "test-token") token := GetCSRFTokenFromContext(ctx) if token != "test-token" { t.Errorf("expected 'test-token', got '%s'", token) } // Test without token in context emptyCtx := context.Background() emptyToken := GetCSRFTokenFromContext(emptyCtx) if emptyToken != "" { t.Errorf("expected empty string, got '%s'", emptyToken) } }