diff options
Diffstat (limited to 'internal')
| -rw-r--r-- | internal/api/http_test.go | 207 | ||||
| -rw-r--r-- | internal/api/todoist_test.go | 183 | ||||
| -rw-r--r-- | internal/api/trello_test.go | 164 | ||||
| -rw-r--r-- | internal/auth/auth_test.go | 196 | ||||
| -rw-r--r-- | internal/config/config_test.go | 289 | ||||
| -rw-r--r-- | internal/handlers/timeline_logic_test.go | 92 | ||||
| -rw-r--r-- | internal/middleware/security_test.go | 200 | ||||
| -rw-r--r-- | internal/models/atom_test.go | 291 | ||||
| -rw-r--r-- | internal/store/sqlite_test.go | 986 |
9 files changed, 2608 insertions, 0 deletions
diff --git a/internal/api/http_test.go b/internal/api/http_test.go new file mode 100644 index 0000000..c2c32ee --- /dev/null +++ b/internal/api/http_test.go @@ -0,0 +1,207 @@ +package api + +import ( + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "testing" +) + +func TestNewBaseClient(t *testing.T) { + client := NewBaseClient("https://api.example.com") + if client.BaseURL != "https://api.example.com" { + t.Errorf("Expected BaseURL 'https://api.example.com', got '%s'", client.BaseURL) + } + if client.HTTPClient == nil { + t.Error("HTTPClient should not be nil") + } +} + +func TestBaseClient_Get(t *testing.T) { + expected := map[string]string{"message": "hello"} + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method != "GET" { + t.Errorf("Expected GET, got %s", r.Method) + } + if r.URL.Path != "/test" { + t.Errorf("Expected path /test, got %s", r.URL.Path) + } + if r.Header.Get("Authorization") != "Bearer token123" { + t.Errorf("Expected Authorization header") + } + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(expected) + })) + defer server.Close() + + client := BaseClient{HTTPClient: server.Client(), BaseURL: server.URL} + var result map[string]string + err := client.Get(context.Background(), "/test", map[string]string{"Authorization": "Bearer token123"}, &result) + if err != nil { + t.Fatalf("Get failed: %v", err) + } + if result["message"] != "hello" { + t.Errorf("Expected message 'hello', got '%s'", result["message"]) + } +} + +func TestBaseClient_Get_Error(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusInternalServerError) + w.Write([]byte("server error")) + })) + defer server.Close() + + client := BaseClient{HTTPClient: server.Client(), BaseURL: server.URL} + var result map[string]string + err := client.Get(context.Background(), "/test", nil, &result) + if err == nil { + t.Error("Expected error for 500 response") + } +} + +func TestBaseClient_Post(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method != "POST" { + t.Errorf("Expected POST, got %s", r.Method) + } + if r.Header.Get("Content-Type") != "application/json" { + t.Error("Expected Content-Type application/json") + } + + var body map[string]string + json.NewDecoder(r.Body).Decode(&body) + if body["name"] != "test" { + t.Errorf("Expected name 'test', got '%s'", body["name"]) + } + + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(map[string]int{"id": 42}) + })) + defer server.Close() + + client := BaseClient{HTTPClient: server.Client(), BaseURL: server.URL} + var result map[string]int + err := client.Post(context.Background(), "/create", nil, map[string]string{"name": "test"}, &result) + if err != nil { + t.Fatalf("Post failed: %v", err) + } + if result["id"] != 42 { + t.Errorf("Expected id 42, got %d", result["id"]) + } +} + +func TestBaseClient_Post_NilBody(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Header.Get("Content-Type") == "application/json" { + t.Error("Content-Type should not be set for nil body") + } + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(map[string]string{"status": "ok"}) + })) + defer server.Close() + + client := BaseClient{HTTPClient: server.Client(), BaseURL: server.URL} + var result map[string]string + err := client.Post(context.Background(), "/action", nil, nil, &result) + if err != nil { + t.Fatalf("Post with nil body failed: %v", err) + } +} + +func TestBaseClient_PostForm(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method != "POST" { + t.Errorf("Expected POST, got %s", r.Method) + } + if r.Header.Get("Content-Type") != "application/x-www-form-urlencoded" { + t.Error("Expected form-urlencoded Content-Type") + } + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(map[string]bool{"success": true}) + })) + defer server.Close() + + client := BaseClient{HTTPClient: server.Client(), BaseURL: server.URL} + var result map[string]bool + err := client.PostForm(context.Background(), "/submit", nil, "key=value", &result) + if err != nil { + t.Fatalf("PostForm failed: %v", err) + } + if !result["success"] { + t.Error("Expected success to be true") + } +} + +func TestBaseClient_PostEmpty(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method != "POST" { + t.Errorf("Expected POST, got %s", r.Method) + } + if r.Header.Get("X-Custom") != "custom-value" { + t.Error("Expected custom header") + } + w.WriteHeader(http.StatusNoContent) + })) + defer server.Close() + + client := BaseClient{HTTPClient: server.Client(), BaseURL: server.URL} + err := client.PostEmpty(context.Background(), "/trigger", map[string]string{"X-Custom": "custom-value"}) + if err != nil { + t.Fatalf("PostEmpty failed: %v", err) + } +} + +func TestBaseClient_PostEmpty_Error(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusBadRequest) + w.Write([]byte("bad request")) + })) + defer server.Close() + + client := BaseClient{HTTPClient: server.Client(), BaseURL: server.URL} + err := client.PostEmpty(context.Background(), "/fail", nil) + if err == nil { + t.Error("Expected error for 400 response") + } +} + +func TestBaseClient_Put(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method != "PUT" { + t.Errorf("Expected PUT, got %s", r.Method) + } + if r.Header.Get("Content-Type") != "application/x-www-form-urlencoded" { + t.Error("Expected form-urlencoded Content-Type") + } + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(map[string]string{"updated": "yes"}) + })) + defer server.Close() + + client := BaseClient{HTTPClient: server.Client(), BaseURL: server.URL} + var result map[string]string + err := client.Put(context.Background(), "/update", nil, "field=value", &result) + if err != nil { + t.Fatalf("Put failed: %v", err) + } + if result["updated"] != "yes" { + t.Error("Expected updated to be yes") + } +} + +func TestBaseClient_doJSON_DecodeError(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.Write([]byte("invalid json")) + })) + defer server.Close() + + client := BaseClient{HTTPClient: server.Client(), BaseURL: server.URL} + var result map[string]string + err := client.Get(context.Background(), "/bad-json", nil, &result) + if err == nil { + t.Error("Expected error for invalid JSON") + } +} diff --git a/internal/api/todoist_test.go b/internal/api/todoist_test.go index 7bbcc1e..f7ca719 100644 --- a/internal/api/todoist_test.go +++ b/internal/api/todoist_test.go @@ -246,3 +246,186 @@ func TestTodoistClient_GetProjects(t *testing.T) { t.Errorf("Project 2 mismatch: got ID=%s Name=%s", projects[1].ID, projects[1].Name) } } + +func TestTodoistClient_GetTasks(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method != "GET" { + t.Errorf("Expected GET, got %s", r.Method) + } + + w.Header().Set("Content-Type", "application/json") + + // GetTasks also calls GetProjects internally + if r.URL.Path == "/projects" { + response := []todoistProjectResponse{ + {ID: "proj-1", Name: "Project 1"}, + } + json.NewEncoder(w).Encode(response) + return + } + + if r.URL.Path == "/tasks" { + response := []todoistTaskResponse{ + {ID: "task-1", Content: "Task 1", ProjectID: "proj-1", CreatedAt: time.Now().Format(time.RFC3339)}, + {ID: "task-2", Content: "Task 2", ProjectID: "proj-1", CreatedAt: time.Now().Format(time.RFC3339)}, + } + json.NewEncoder(w).Encode(response) + return + } + + t.Errorf("Unexpected path: %s", r.URL.Path) + })) + defer server.Close() + + client := newTestTodoistClient(server.URL, "test-key") + tasks, err := client.GetTasks(context.Background()) + if err != nil { + t.Fatalf("GetTasks failed: %v", err) + } + + if len(tasks) != 2 { + t.Errorf("Expected 2 tasks, got %d", len(tasks)) + } +} + +func TestTodoistClient_ReopenTask(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method != "POST" { + t.Errorf("Expected POST, got %s", r.Method) + } + expectedPath := "/tasks/task-123/reopen" + if r.URL.Path != expectedPath { + t.Errorf("Expected path %s, got %s", expectedPath, r.URL.Path) + } + + w.WriteHeader(http.StatusNoContent) + })) + defer server.Close() + + client := newTestTodoistClient(server.URL, "test-key") + err := client.ReopenTask(context.Background(), "task-123") + if err != nil { + t.Fatalf("ReopenTask failed: %v", err) + } +} + +func TestTodoistClient_UpdateTask(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method != "POST" { + t.Errorf("Expected POST, got %s", r.Method) + } + expectedPath := "/tasks/task-123" + if r.URL.Path != expectedPath { + t.Errorf("Expected path %s, got %s", expectedPath, r.URL.Path) + } + + var payload map[string]interface{} + json.NewDecoder(r.Body).Decode(&payload) + if payload["content"] != "Updated Content" { + t.Errorf("Expected content 'Updated Content', got %v", payload["content"]) + } + + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(map[string]string{"id": "task-123"}) + })) + defer server.Close() + + client := newTestTodoistClient(server.URL, "test-key") + err := client.UpdateTask(context.Background(), "task-123", map[string]interface{}{"content": "Updated Content"}) + if err != nil { + t.Fatalf("UpdateTask failed: %v", err) + } +} + +func TestTodoistClient_Sync(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method != "POST" { + t.Errorf("Expected POST, got %s", r.Method) + } + if r.URL.Path != "/sync" { + t.Errorf("Expected path /sync, got %s", r.URL.Path) + } + + response := TodoistSyncResponse{ + SyncToken: "new-sync-token", + FullSync: true, + Items: []SyncItemResponse{ + {ID: "item-1", Content: "Item 1", ProjectID: "proj-1"}, + }, + Projects: []SyncProjectResponse{ + {ID: "proj-1", Name: "Project 1"}, + }, + } + + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(response) + })) + defer server.Close() + + client := newTestTodoistClient(server.URL, "test-key") + resp, err := client.Sync(context.Background(), "*") + if err != nil { + t.Fatalf("Sync failed: %v", err) + } + + if resp.SyncToken != "new-sync-token" { + t.Errorf("Expected sync token 'new-sync-token', got '%s'", resp.SyncToken) + } + if len(resp.Items) != 1 { + t.Errorf("Expected 1 item, got %d", len(resp.Items)) + } +} + +func TestConvertSyncItemsToTasks(t *testing.T) { + projects := map[string]string{ + "proj-1": "Project 1", + } + + items := []SyncItemResponse{ + { + ID: "item-1", + Content: "Task 1", + Description: "Description 1", + ProjectID: "proj-1", + Priority: 3, + Labels: []string{"label1"}, + }, + { + ID: "item-2", + Content: "Completed Task", + ProjectID: "proj-1", + IsCompleted: true, + }, + } + + tasks := ConvertSyncItemsToTasks(items, projects) + + // Should skip completed task + if len(tasks) != 1 { + t.Errorf("Expected 1 task (excluding completed), got %d", len(tasks)) + } + + if tasks[0].ID != "item-1" { + t.Errorf("Expected task ID 'item-1', got '%s'", tasks[0].ID) + } + if tasks[0].ProjectName != "Project 1" { + t.Errorf("Expected project name 'Project 1', got '%s'", tasks[0].ProjectName) + } +} + +func TestBuildProjectMapFromSync(t *testing.T) { + projects := []SyncProjectResponse{ + {ID: "proj-1", Name: "Project 1"}, + {ID: "proj-2", Name: "Project 2"}, + } + + projectMap := BuildProjectMapFromSync(projects) + + if len(projectMap) != 2 { + t.Errorf("Expected 2 projects in map, got %d", len(projectMap)) + } + + if projectMap["proj-1"] != "Project 1" { + t.Errorf("Expected 'Project 1', got '%s'", projectMap["proj-1"]) + } +} diff --git a/internal/api/trello_test.go b/internal/api/trello_test.go index d677363..a209d01 100644 --- a/internal/api/trello_test.go +++ b/internal/api/trello_test.go @@ -239,3 +239,167 @@ func parseFormData(data string) (map[string]string, error) { } return result, nil } + +func TestTrelloClient_GetBoards(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method != "GET" { + t.Errorf("Expected GET, got %s", r.Method) + } + if r.URL.Path != "/members/me/boards" { + t.Errorf("Expected path /members/me/boards, got %s", r.URL.Path) + } + + response := []trelloBoardResponse{ + {ID: "board-1", Name: "Board 1"}, + {ID: "board-2", Name: "Board 2"}, + } + + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(response) + })) + defer server.Close() + + client := newTestTrelloClient(server.URL, "test-key", "test-token") + boards, err := client.GetBoards(context.Background()) + if err != nil { + t.Fatalf("GetBoards failed: %v", err) + } + + if len(boards) != 2 { + t.Errorf("Expected 2 boards, got %d", len(boards)) + } + if boards[0].ID != "board-1" { + t.Errorf("Expected board ID 'board-1', got '%s'", boards[0].ID) + } +} + +func TestTrelloClient_GetCards(t *testing.T) { + dueDate := "2024-01-15T12:00:00Z" + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method != "GET" { + t.Errorf("Expected GET, got %s", r.Method) + } + + w.Header().Set("Content-Type", "application/json") + + // GetCards calls getLists internally + if strings.Contains(r.URL.Path, "/lists") { + response := []trelloListResponse{ + {ID: "list-1", Name: "To Do"}, + } + json.NewEncoder(w).Encode(response) + return + } + + if strings.Contains(r.URL.Path, "/cards") { + response := []trelloCardResponse{ + {ID: "card-1", Name: "Card 1", IDList: "list-1", IDBoard: "board-1", Due: &dueDate}, + {ID: "card-2", Name: "Card 2", IDList: "list-1", IDBoard: "board-1"}, + } + json.NewEncoder(w).Encode(response) + return + } + + t.Errorf("Unexpected path: %s", r.URL.Path) + })) + defer server.Close() + + client := newTestTrelloClient(server.URL, "test-key", "test-token") + cards, err := client.GetCards(context.Background(), "board-1") + if err != nil { + t.Fatalf("GetCards failed: %v", err) + } + + if len(cards) != 2 { + t.Errorf("Expected 2 cards, got %d", len(cards)) + } + if cards[0].DueDate == nil { + t.Error("Expected due date for card 1") + } +} + +func TestTrelloClient_GetLists(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method != "GET" { + t.Errorf("Expected GET, got %s", r.Method) + } + if !strings.Contains(r.URL.Path, "/boards/board-1/lists") { + t.Errorf("Expected path to contain /boards/board-1/lists, got %s", r.URL.Path) + } + + response := []trelloListResponse{ + {ID: "list-1", Name: "To Do"}, + {ID: "list-2", Name: "Done"}, + } + + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(response) + })) + defer server.Close() + + client := newTestTrelloClient(server.URL, "test-key", "test-token") + lists, err := client.GetLists(context.Background(), "board-1") + if err != nil { + t.Fatalf("GetLists failed: %v", err) + } + + if len(lists) != 2 { + t.Errorf("Expected 2 lists, got %d", len(lists)) + } + if lists[0].Name != "To Do" { + t.Errorf("Expected 'To Do', got '%s'", lists[0].Name) + } +} + +func TestTrelloClient_GetBoardsWithCards(t *testing.T) { + requestCount := 0 + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + requestCount++ + + if strings.Contains(r.URL.Path, "/members/me/boards") { + response := []trelloBoardResponse{ + {ID: "board-1", Name: "Board 1"}, + } + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(response) + return + } + + if strings.Contains(r.URL.Path, "/lists") { + response := []trelloListResponse{ + {ID: "list-1", Name: "To Do"}, + } + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(response) + return + } + + if strings.Contains(r.URL.Path, "/cards") { + response := []trelloCardResponse{ + {ID: "card-1", Name: "Card 1", IDList: "list-1", IDBoard: "board-1"}, + } + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(response) + return + } + + t.Errorf("Unexpected path: %s", r.URL.Path) + })) + defer server.Close() + + client := newTestTrelloClient(server.URL, "test-key", "test-token") + boards, err := client.GetBoardsWithCards(context.Background()) + if err != nil { + t.Fatalf("GetBoardsWithCards failed: %v", err) + } + + if len(boards) != 1 { + t.Errorf("Expected 1 board, got %d", len(boards)) + } + if len(boards[0].Cards) != 1 { + t.Errorf("Expected 1 card, got %d", len(boards[0].Cards)) + } + if boards[0].Cards[0].ListName != "To Do" { + t.Errorf("Expected list name 'To Do', got '%s'", boards[0].Cards[0].ListName) + } +} diff --git a/internal/auth/auth_test.go b/internal/auth/auth_test.go index 013a4aa..fbe582b 100644 --- a/internal/auth/auth_test.go +++ b/internal/auth/auth_test.go @@ -1,6 +1,7 @@ package auth import ( + "context" "database/sql" "errors" "testing" @@ -106,3 +107,198 @@ func TestCreateUser(t *testing.T) { t.Errorf("there were unfulfilled expectations: %s", err) } } + +func TestUserCount(t *testing.T) { + db, mock, err := sqlmock.New() + if err != nil { + t.Fatalf("an error '%s' was not expected when opening a stub database connection", err) + } + defer func() { _ = db.Close() }() + + service := NewService(db) + + rows := sqlmock.NewRows([]string{"count"}).AddRow(5) + mock.ExpectQuery("SELECT COUNT").WillReturnRows(rows) + + count, err := service.UserCount() + if err != nil { + t.Errorf("expected no error, got %v", err) + } + if count != 5 { + t.Errorf("expected count 5, got %d", count) + } +} + +func TestEnsureDefaultUser_NoUsers(t *testing.T) { + db, mock, err := sqlmock.New() + if err != nil { + t.Fatalf("an error '%s' was not expected when opening a stub database connection", err) + } + defer func() { _ = db.Close() }() + + service := NewService(db) + + // Expect count query (no users) + countRows := sqlmock.NewRows([]string{"count"}).AddRow(0) + mock.ExpectQuery("SELECT COUNT").WillReturnRows(countRows) + + // Expect check if user exists + mock.ExpectQuery("SELECT id, username, password_hash, created_at FROM users WHERE username = ?"). + WithArgs("admin"). + WillReturnError(sql.ErrNoRows) + + // Expect insert + mock.ExpectExec("INSERT INTO users"). + WithArgs("admin", sqlmock.AnyArg()). + WillReturnResult(sqlmock.NewResult(1, 1)) + + // Expect retrieve created user + rows := sqlmock.NewRows([]string{"id", "username", "password_hash", "created_at"}). + AddRow(1, "admin", "hashedpassword", time.Now()) + mock.ExpectQuery("SELECT id, username, password_hash, created_at FROM users WHERE id = ?"). + WithArgs(1). + WillReturnRows(rows) + + err = service.EnsureDefaultUser("admin", "password") + if err != nil { + t.Errorf("expected no error, got %v", err) + } + + if err := mock.ExpectationsWereMet(); err != nil { + t.Errorf("there were unfulfilled expectations: %s", err) + } +} + +func TestEnsureDefaultUser_UsersExist(t *testing.T) { + db, mock, err := sqlmock.New() + if err != nil { + t.Fatalf("an error '%s' was not expected when opening a stub database connection", err) + } + defer func() { _ = db.Close() }() + + service := NewService(db) + + // Expect count query (users exist) + countRows := sqlmock.NewRows([]string{"count"}).AddRow(1) + mock.ExpectQuery("SELECT COUNT").WillReturnRows(countRows) + + // Should not create user when users exist + err = service.EnsureDefaultUser("admin", "password") + if err != nil { + t.Errorf("expected no error, got %v", err) + } + + if err := mock.ExpectationsWereMet(); err != nil { + t.Errorf("there were unfulfilled expectations: %s", err) + } +} + +func TestAuthenticate_WrongPassword(t *testing.T) { + db, mock, err := sqlmock.New() + if err != nil { + t.Fatalf("an error '%s' was not expected when opening a stub database connection", err) + } + defer func() { _ = db.Close() }() + + service := NewService(db) + + // Hash a different password + hash, _ := bcrypt.GenerateFromPassword([]byte("correctpassword"), bcrypt.DefaultCost) + + rows := sqlmock.NewRows([]string{"id", "username", "password_hash", "created_at"}). + AddRow(1, "testuser", string(hash), time.Now()) + + mock.ExpectQuery("SELECT id, username, password_hash, created_at FROM users WHERE username = ?"). + WithArgs("testuser"). + WillReturnRows(rows) + + // Try with wrong password + _, err = service.Authenticate("testuser", "wrongpassword") + if !errors.Is(err, ErrInvalidCredentials) { + t.Errorf("expected ErrInvalidCredentials, got %v", err) + } +} + +func TestCreateUser_AlreadyExists(t *testing.T) { + db, mock, err := sqlmock.New() + if err != nil { + t.Fatalf("an error '%s' was not expected when opening a stub database connection", err) + } + defer func() { _ = db.Close() }() + + service := NewService(db) + + // User already exists + hash, _ := bcrypt.GenerateFromPassword([]byte("password"), bcrypt.DefaultCost) + rows := sqlmock.NewRows([]string{"id", "username", "password_hash", "created_at"}). + AddRow(1, "existinguser", string(hash), time.Now()) + + mock.ExpectQuery("SELECT id, username, password_hash, created_at FROM users WHERE username = ?"). + WithArgs("existinguser"). + WillReturnRows(rows) + + _, err = service.CreateUser("existinguser", "password") + if !errors.Is(err, ErrUserExists) { + t.Errorf("expected ErrUserExists, got %v", err) + } +} + +func TestGetUserByID(t *testing.T) { + db, mock, err := sqlmock.New() + if err != nil { + t.Fatalf("an error '%s' was not expected when opening a stub database connection", err) + } + defer func() { _ = db.Close() }() + + service := NewService(db) + + rows := sqlmock.NewRows([]string{"id", "username", "password_hash", "created_at"}). + AddRow(42, "testuser", "hashedpassword", time.Now()) + + mock.ExpectQuery("SELECT id, username, password_hash, created_at FROM users WHERE id = ?"). + WithArgs(int64(42)). + WillReturnRows(rows) + + user, err := service.GetUserByID(42) + if err != nil { + t.Errorf("expected no error, got %v", err) + } + if user.ID != 42 { + t.Errorf("expected user ID 42, got %d", user.ID) + } +} + +func TestGetUserByID_NotFound(t *testing.T) { + db, mock, err := sqlmock.New() + if err != nil { + t.Fatalf("an error '%s' was not expected when opening a stub database connection", err) + } + defer func() { _ = db.Close() }() + + service := NewService(db) + + mock.ExpectQuery("SELECT id, username, password_hash, created_at FROM users WHERE id = ?"). + WithArgs(int64(999)). + WillReturnError(sql.ErrNoRows) + + _, err = service.GetUserByID(999) + if err == nil { + t.Error("expected error for non-existent user") + } +} + +func TestGetCSRFTokenFromContext(t *testing.T) { + // Test with token in context + ctx := context.WithValue(context.Background(), ContextKeyCSRF, "test-token") + token := GetCSRFTokenFromContext(ctx) + if token != "test-token" { + t.Errorf("expected 'test-token', got '%s'", token) + } + + // Test without token in context + emptyCtx := context.Background() + emptyToken := GetCSRFTokenFromContext(emptyCtx) + if emptyToken != "" { + t.Errorf("expected empty string, got '%s'", emptyToken) + } +} diff --git a/internal/config/config_test.go b/internal/config/config_test.go new file mode 100644 index 0000000..41cd6e0 --- /dev/null +++ b/internal/config/config_test.go @@ -0,0 +1,289 @@ +package config + +import ( + "os" + "testing" + "time" +) + +func TestConfigValidate(t *testing.T) { + tests := []struct { + name string + cfg Config + wantErr bool + }{ + { + name: "valid config", + cfg: Config{ + TodoistAPIKey: "todoist-key", + TrelloAPIKey: "trello-key", + TrelloToken: "trello-token", + }, + wantErr: false, + }, + { + name: "missing todoist key", + cfg: Config{ + TrelloAPIKey: "trello-key", + TrelloToken: "trello-token", + }, + wantErr: true, + }, + { + name: "missing trello key", + cfg: Config{ + TodoistAPIKey: "todoist-key", + TrelloToken: "trello-token", + }, + wantErr: true, + }, + { + name: "missing trello token", + cfg: Config{ + TodoistAPIKey: "todoist-key", + TrelloAPIKey: "trello-key", + }, + wantErr: true, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + err := tc.cfg.Validate() + if (err != nil) != tc.wantErr { + t.Errorf("Validate() error = %v, wantErr %v", err, tc.wantErr) + } + }) + } +} + +func TestConfigHasMethods(t *testing.T) { + cfg := Config{ + PlanToEatAPIKey: "pte-key", + TrelloAPIKey: "trello-key", + TrelloToken: "trello-token", + GoogleCredentialsFile: "/path/to/creds.json", + GoogleTasksListID: "@default", + } + + if !cfg.HasPlanToEat() { + t.Error("HasPlanToEat should return true") + } + + if !cfg.HasTrello() { + t.Error("HasTrello should return true") + } + + if !cfg.HasGoogleCalendar() { + t.Error("HasGoogleCalendar should return true") + } + + if !cfg.HasGoogleTasks() { + t.Error("HasGoogleTasks should return true") + } + + // Test with empty config + emptyCfg := Config{} + if emptyCfg.HasPlanToEat() { + t.Error("HasPlanToEat should return false for empty config") + } + if emptyCfg.HasTrello() { + t.Error("HasTrello should return false for empty config") + } + if emptyCfg.HasGoogleCalendar() { + t.Error("HasGoogleCalendar should return false for empty config") + } + if emptyCfg.HasGoogleTasks() { + t.Error("HasGoogleTasks should return false for empty config") + } + + // Test session check + sessionCfg := Config{PlanToEatSession: "session-cookie"} + if !sessionCfg.HasPlanToEat() { + t.Error("HasPlanToEat should return true for session") + } + if !sessionCfg.HasPlanToEatSession() { + t.Error("HasPlanToEatSession should return true") + } +} + +func TestGetEnvWithDefault(t *testing.T) { + // Test with set env var + os.Setenv("TEST_CONFIG_VAR", "test_value") + defer os.Unsetenv("TEST_CONFIG_VAR") + + if val := getEnvWithDefault("TEST_CONFIG_VAR", "default"); val != "test_value" { + t.Errorf("Expected 'test_value', got '%s'", val) + } + + // Test with unset env var + if val := getEnvWithDefault("UNSET_CONFIG_VAR", "default"); val != "default" { + t.Errorf("Expected 'default', got '%s'", val) + } +} + +func TestGetEnvAsInt(t *testing.T) { + // Test with valid int + os.Setenv("TEST_INT_VAR", "42") + defer os.Unsetenv("TEST_INT_VAR") + + if val := getEnvAsInt("TEST_INT_VAR", 10); val != 42 { + t.Errorf("Expected 42, got %d", val) + } + + // Test with invalid int + os.Setenv("TEST_INVALID_INT", "not_a_number") + defer os.Unsetenv("TEST_INVALID_INT") + + if val := getEnvAsInt("TEST_INVALID_INT", 10); val != 10 { + t.Errorf("Expected default 10 for invalid int, got %d", val) + } + + // Test with unset var + if val := getEnvAsInt("UNSET_INT_VAR", 10); val != 10 { + t.Errorf("Expected default 10, got %d", val) + } +} + +func TestGetEnvAsBool(t *testing.T) { + // Test with true values + os.Setenv("TEST_BOOL_TRUE", "true") + defer os.Unsetenv("TEST_BOOL_TRUE") + + if val := getEnvAsBool("TEST_BOOL_TRUE", false); !val { + t.Error("Expected true") + } + + // Test with false values + os.Setenv("TEST_BOOL_FALSE", "false") + defer os.Unsetenv("TEST_BOOL_FALSE") + + if val := getEnvAsBool("TEST_BOOL_FALSE", true); val { + t.Error("Expected false") + } + + // Test with invalid bool + os.Setenv("TEST_INVALID_BOOL", "maybe") + defer os.Unsetenv("TEST_INVALID_BOOL") + + if val := getEnvAsBool("TEST_INVALID_BOOL", true); !val { + t.Error("Expected default true for invalid bool") + } + + // Test with unset var + if val := getEnvAsBool("UNSET_BOOL_VAR", true); !val { + t.Error("Expected default true") + } +} + +// Timezone tests +func TestGetDisplayTimezone(t *testing.T) { + // Before SetDisplayTimezone is called, should return UTC + loc := GetDisplayTimezone() + if loc == nil { + t.Fatal("GetDisplayTimezone should not return nil") + } +} + +func TestNow(t *testing.T) { + now := Now() + // Just verify it returns a valid time + if now.IsZero() { + t.Error("Now() should not return zero time") + } +} + +func TestToday(t *testing.T) { + today := Today() + + // Today should have zero hours, minutes, seconds + if today.Hour() != 0 || today.Minute() != 0 || today.Second() != 0 { + t.Error("Today() should return midnight") + } +} + +func TestParseDateInDisplayTZ(t *testing.T) { + parsed, err := ParseDateInDisplayTZ("2024-01-15") + if err != nil { + t.Fatalf("ParseDateInDisplayTZ failed: %v", err) + } + + if parsed.Year() != 2024 || parsed.Month() != time.January || parsed.Day() != 15 { + t.Errorf("Unexpected date: %v", parsed) + } + + // Test invalid date + _, err = ParseDateInDisplayTZ("invalid") + if err == nil { + t.Error("Expected error for invalid date") + } +} + +func TestParseDateTimeInDisplayTZ(t *testing.T) { + parsed, err := ParseDateTimeInDisplayTZ("2006-01-02 15:04", "2024-01-15 14:30") + if err != nil { + t.Fatalf("ParseDateTimeInDisplayTZ failed: %v", err) + } + + if parsed.Hour() != 14 || parsed.Minute() != 30 { + t.Errorf("Unexpected time: %v", parsed) + } +} + +func TestToDisplayTZ(t *testing.T) { + utcTime := time.Date(2024, 1, 15, 12, 0, 0, 0, time.UTC) + converted := ToDisplayTZ(utcTime) + + // Just verify it doesn't panic and returns a valid time + if converted.IsZero() { + t.Error("ToDisplayTZ should not return zero time") + } +} + +func TestLoad(t *testing.T) { + // Set up required env vars + os.Setenv("TODOIST_API_KEY", "test-todoist-key") + os.Setenv("TRELLO_API_KEY", "test-trello-key") + os.Setenv("TRELLO_TOKEN", "test-trello-token") + os.Setenv("PORT", "9999") + os.Setenv("CACHE_TTL_MINUTES", "10") + os.Setenv("DEBUG", "true") + defer func() { + os.Unsetenv("TODOIST_API_KEY") + os.Unsetenv("TRELLO_API_KEY") + os.Unsetenv("TRELLO_TOKEN") + os.Unsetenv("PORT") + os.Unsetenv("CACHE_TTL_MINUTES") + os.Unsetenv("DEBUG") + }() + + cfg, err := Load() + if err != nil { + t.Fatalf("Load failed: %v", err) + } + + if cfg.TodoistAPIKey != "test-todoist-key" { + t.Errorf("Expected TodoistAPIKey 'test-todoist-key', got '%s'", cfg.TodoistAPIKey) + } + if cfg.Port != "9999" { + t.Errorf("Expected Port '9999', got '%s'", cfg.Port) + } + if cfg.CacheTTLMinutes != 10 { + t.Errorf("Expected CacheTTLMinutes 10, got %d", cfg.CacheTTLMinutes) + } + if !cfg.Debug { + t.Error("Expected Debug to be true") + } +} + +func TestLoad_ValidationError(t *testing.T) { + // Clear required env vars to trigger validation error + os.Unsetenv("TODOIST_API_KEY") + os.Unsetenv("TRELLO_API_KEY") + os.Unsetenv("TRELLO_TOKEN") + + _, err := Load() + if err == nil { + t.Error("Expected validation error when required env vars are missing") + } +} diff --git a/internal/handlers/timeline_logic_test.go b/internal/handlers/timeline_logic_test.go index 5fe995f..9a71741 100644 --- a/internal/handlers/timeline_logic_test.go +++ b/internal/handlers/timeline_logic_test.go @@ -162,3 +162,95 @@ func TestBuildTimeline(t *testing.T) { t.Errorf("Expected item 3 to be Card, got %s", items[3].Type) } } + +func TestCalcCalendarBounds(t *testing.T) { + tests := []struct { + name string + items []models.TimelineItem + currentHour int + wantStart int + wantEnd int + }{ + { + name: "no timed events returns default", + items: []models.TimelineItem{}, + currentHour: -1, + wantStart: 8, + wantEnd: 18, + }, + { + name: "single event at 10am", + items: []models.TimelineItem{ + {Time: time.Date(2023, 1, 1, 10, 0, 0, 0, time.UTC)}, + }, + currentHour: -1, + wantStart: 9, // 1 hour buffer before + wantEnd: 11, // 1 hour buffer after + }, + { + name: "includes current hour", + items: []models.TimelineItem{ + {Time: time.Date(2023, 1, 1, 10, 0, 0, 0, time.UTC)}, + }, + currentHour: 8, + wantStart: 7, // 1 hour before 8am + wantEnd: 11, // 1 hour after 10am + }, + { + name: "event with end time extends range", + items: []models.TimelineItem{ + { + Time: time.Date(2023, 1, 1, 10, 0, 0, 0, time.UTC), + EndTime: timePtr(time.Date(2023, 1, 1, 14, 0, 0, 0, time.UTC)), + }, + }, + currentHour: -1, + wantStart: 9, // 1 hour before 10am + wantEnd: 15, // 1 hour after 2pm end + }, + { + name: "all-day events are skipped", + items: []models.TimelineItem{ + {Time: time.Date(2023, 1, 1, 0, 0, 0, 0, time.UTC), IsAllDay: true}, + }, + currentHour: -1, + wantStart: 8, + wantEnd: 18, + }, + { + name: "overdue events are skipped", + items: []models.TimelineItem{ + {Time: time.Date(2023, 1, 1, 10, 0, 0, 0, time.UTC), IsOverdue: true}, + }, + currentHour: -1, + wantStart: 8, + wantEnd: 18, + }, + { + name: "clamps to 0-23 range", + items: []models.TimelineItem{ + {Time: time.Date(2023, 1, 1, 0, 30, 0, 0, time.UTC)}, + {Time: time.Date(2023, 1, 1, 23, 0, 0, 0, time.UTC)}, + }, + currentHour: -1, + wantStart: 0, // Can't go below 0 + wantEnd: 23, // Can't go above 23 + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + start, end := calcCalendarBounds(tc.items, tc.currentHour) + if start != tc.wantStart { + t.Errorf("Expected start %d, got %d", tc.wantStart, start) + } + if end != tc.wantEnd { + t.Errorf("Expected end %d, got %d", tc.wantEnd, end) + } + }) + } +} + +func timePtr(t time.Time) *time.Time { + return &t +} diff --git a/internal/middleware/security_test.go b/internal/middleware/security_test.go new file mode 100644 index 0000000..1717418 --- /dev/null +++ b/internal/middleware/security_test.go @@ -0,0 +1,200 @@ +package middleware + +import ( + "net/http" + "net/http/httptest" + "testing" + "time" +) + +func TestSecurityHeaders_Debug(t *testing.T) { + handler := SecurityHeaders(true)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })) + + req := httptest.NewRequest("GET", "/", nil) + rec := httptest.NewRecorder() + handler.ServeHTTP(rec, req) + + // Check common security headers + if rec.Header().Get("X-Content-Type-Options") != "nosniff" { + t.Error("Expected X-Content-Type-Options header") + } + if rec.Header().Get("X-Frame-Options") != "DENY" { + t.Error("Expected X-Frame-Options header") + } + if rec.Header().Get("X-XSS-Protection") != "1; mode=block" { + t.Error("Expected X-XSS-Protection header") + } + + // HSTS should NOT be set in debug mode + if rec.Header().Get("Strict-Transport-Security") != "" { + t.Error("HSTS should not be set in debug mode") + } +} + +func TestSecurityHeaders_Production(t *testing.T) { + handler := SecurityHeaders(false)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })) + + req := httptest.NewRequest("GET", "/", nil) + rec := httptest.NewRecorder() + handler.ServeHTTP(rec, req) + + // HSTS should be set in production + if rec.Header().Get("Strict-Transport-Security") == "" { + t.Error("HSTS should be set in production mode") + } + + // CSP should be set + if rec.Header().Get("Content-Security-Policy") == "" { + t.Error("Expected Content-Security-Policy header") + } +} + +func TestRateLimiter_Allow(t *testing.T) { + rl := &RateLimiter{ + requests: make(map[string][]time.Time), + limit: 3, + window: time.Minute, + } + + ip := "192.168.1.1" + + // First 3 requests should be allowed + for i := 0; i < 3; i++ { + if !rl.Allow(ip) { + t.Errorf("Request %d should be allowed", i+1) + } + } + + // 4th request should be denied + if rl.Allow(ip) { + t.Error("4th request should be denied") + } +} + +func TestRateLimiter_WindowExpiry(t *testing.T) { + rl := &RateLimiter{ + requests: make(map[string][]time.Time), + limit: 2, + window: 50 * time.Millisecond, + } + + ip := "192.168.1.1" + + // Use up the limit + rl.Allow(ip) + rl.Allow(ip) + + // Should be denied + if rl.Allow(ip) { + t.Error("Should be denied when limit reached") + } + + // Wait for window to expire + time.Sleep(60 * time.Millisecond) + + // Should be allowed again + if !rl.Allow(ip) { + t.Error("Should be allowed after window expires") + } +} + +func TestRateLimiter_Limit_Middleware(t *testing.T) { + rl := &RateLimiter{ + requests: make(map[string][]time.Time), + limit: 1, + window: time.Minute, + } + + handler := rl.Limit(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })) + + // First request should pass + req1 := httptest.NewRequest("GET", "/", nil) + req1.RemoteAddr = "10.0.0.1:12345" + rec1 := httptest.NewRecorder() + handler.ServeHTTP(rec1, req1) + + if rec1.Code != http.StatusOK { + t.Errorf("First request should return 200, got %d", rec1.Code) + } + + // Second request should be rate limited + req2 := httptest.NewRequest("GET", "/", nil) + req2.RemoteAddr = "10.0.0.1:12345" + rec2 := httptest.NewRecorder() + handler.ServeHTTP(rec2, req2) + + if rec2.Code != http.StatusTooManyRequests { + t.Errorf("Second request should return 429, got %d", rec2.Code) + } +} + +func TestNewRateLimiter(t *testing.T) { + rl := NewRateLimiter(10, 100*time.Millisecond) + if rl == nil { + t.Fatal("NewRateLimiter returned nil") + } + if rl.limit != 10 { + t.Errorf("Expected limit 10, got %d", rl.limit) + } + if rl.window != 100*time.Millisecond { + t.Errorf("Expected window 100ms, got %v", rl.window) + } + // Let cleanup run once + time.Sleep(150 * time.Millisecond) +} + +func TestGetIP(t *testing.T) { + tests := []struct { + name string + xff string + xri string + remoteAddr string + expected string + }{ + { + name: "X-Forwarded-For takes priority", + xff: "1.2.3.4", + xri: "5.6.7.8", + remoteAddr: "9.10.11.12", + expected: "1.2.3.4", + }, + { + name: "X-Real-IP when no XFF", + xff: "", + xri: "5.6.7.8", + remoteAddr: "9.10.11.12", + expected: "5.6.7.8", + }, + { + name: "RemoteAddr as fallback", + xff: "", + xri: "", + remoteAddr: "9.10.11.12:54321", + expected: "9.10.11.12:54321", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + req := httptest.NewRequest("GET", "/", nil) + req.RemoteAddr = tc.remoteAddr + if tc.xff != "" { + req.Header.Set("X-Forwarded-For", tc.xff) + } + if tc.xri != "" { + req.Header.Set("X-Real-IP", tc.xri) + } + + ip := getIP(req) + if ip != tc.expected { + t.Errorf("Expected IP %s, got %s", tc.expected, ip) + } + }) + } +} diff --git a/internal/models/atom_test.go b/internal/models/atom_test.go new file mode 100644 index 0000000..537d232 --- /dev/null +++ b/internal/models/atom_test.go @@ -0,0 +1,291 @@ +package models + +import ( + "testing" + "time" +) + +func TestTaskToAtom(t *testing.T) { + now := time.Now() + task := Task{ + ID: "task-123", + Content: "Test task", + Description: "Task description", + DueDate: &now, + Priority: 3, + URL: "https://todoist.com/task/123", + CreatedAt: now, + IsRecurring: true, + } + + atom := TaskToAtom(task) + + if atom.ID != "task-123" { + t.Errorf("Expected ID 'task-123', got '%s'", atom.ID) + } + if atom.Title != "Test task" { + t.Errorf("Expected title 'Test task', got '%s'", atom.Title) + } + if atom.Source != SourceTodoist { + t.Errorf("Expected source Todoist, got '%s'", atom.Source) + } + if atom.Type != TypeTask { + t.Errorf("Expected type Task, got '%s'", atom.Type) + } + if atom.Priority != 3 { + t.Errorf("Expected priority 3, got %d", atom.Priority) + } + if atom.SourceIcon != "🔴" { + t.Error("Expected red circle icon") + } + if !atom.IsRecurring { + t.Error("Expected IsRecurring to be true") + } +} + +func TestTaskToAtom_PriorityClamping(t *testing.T) { + // Test priority below 1 + lowTask := Task{Priority: 0} + lowAtom := TaskToAtom(lowTask) + if lowAtom.Priority != 1 { + t.Errorf("Priority should be clamped to 1, got %d", lowAtom.Priority) + } + + // Test priority above 4 + highTask := Task{Priority: 10} + highAtom := TaskToAtom(highTask) + if highAtom.Priority != 4 { + t.Errorf("Priority should be clamped to 4, got %d", highAtom.Priority) + } +} + +func TestCardToAtom(t *testing.T) { + now := time.Now() + card := Card{ + ID: "card-456", + Name: "Test card", + ListName: "To Do", + DueDate: &now, + URL: "https://trello.com/c/456", + } + + atom := CardToAtom(card) + + if atom.ID != "card-456" { + t.Errorf("Expected ID 'card-456', got '%s'", atom.ID) + } + if atom.Title != "Test card" { + t.Errorf("Expected title 'Test card', got '%s'", atom.Title) + } + if atom.Description != "To Do" { + t.Errorf("Expected description 'To Do', got '%s'", atom.Description) + } + if atom.Source != SourceTrello { + t.Errorf("Expected source Trello, got '%s'", atom.Source) + } + if atom.Priority != 2 { + t.Errorf("Expected default priority 2, got %d", atom.Priority) + } + if atom.SourceIcon != "📋" { + t.Error("Expected clipboard icon") + } +} + +func TestMealToAtom(t *testing.T) { + date := time.Now() + meal := Meal{ + ID: "meal-789", + RecipeName: "Pasta", + MealType: "dinner", + Date: date, + RecipeURL: "https://plantoeat.com/recipe/789", + } + + atom := MealToAtom(meal) + + if atom.ID != "meal-789" { + t.Errorf("Expected ID 'meal-789', got '%s'", atom.ID) + } + if atom.Title != "Pasta" { + t.Errorf("Expected title 'Pasta', got '%s'", atom.Title) + } + if atom.Description != "dinner" { + t.Errorf("Expected description 'dinner', got '%s'", atom.Description) + } + if atom.Source != SourceMeal { + t.Errorf("Expected source Meal, got '%s'", atom.Source) + } + if atom.Type != TypeMeal { + t.Errorf("Expected type Meal, got '%s'", atom.Type) + } + if atom.Priority != 1 { + t.Errorf("Expected priority 1, got %d", atom.Priority) + } +} + +func TestBugToAtom(t *testing.T) { + now := time.Now() + bug := Bug{ + ID: 42, + Description: "Something is broken", + CreatedAt: now, + } + + atom := BugToAtom(bug) + + if atom.ID != "bug-42" { + t.Errorf("Expected ID 'bug-42', got '%s'", atom.ID) + } + if atom.Title != "Something is broken" { + t.Errorf("Expected title 'Something is broken', got '%s'", atom.Title) + } + if atom.Source != SourceBug { + t.Errorf("Expected source Bug, got '%s'", atom.Source) + } + if atom.Type != TypeBug { + t.Errorf("Expected type Bug, got '%s'", atom.Type) + } + if atom.Priority != 3 { + t.Errorf("Expected high priority 3, got %d", atom.Priority) + } + if atom.SourceIcon != "🐛" { + t.Error("Expected bug icon") + } +} + +func TestAtom_ComputeUIFields(t *testing.T) { + // Test nil due date + t.Run("nil due date", func(t *testing.T) { + atom := Atom{} + atom.ComputeUIFields() + // Should not panic and fields should remain default + if atom.IsOverdue || atom.IsFuture || atom.HasSetTime { + t.Error("Fields should be false for nil due date") + } + }) + + // Test with due date at midnight (no specific time) + t.Run("midnight due date", func(t *testing.T) { + now := time.Now() + midnightTomorrow := time.Date(now.Year(), now.Month(), now.Day()+1, 0, 0, 0, 0, time.UTC) + atom := Atom{DueDate: &midnightTomorrow} + atom.ComputeUIFields() + if atom.HasSetTime { + t.Error("HasSetTime should be false for midnight") + } + }) + + // Test with specific time + t.Run("specific time", func(t *testing.T) { + now := time.Now() + withTime := time.Date(now.Year(), now.Month(), now.Day()+1, 14, 30, 0, 0, time.UTC) + atom := Atom{DueDate: &withTime} + atom.ComputeUIFields() + if !atom.HasSetTime { + t.Error("HasSetTime should be true for 14:30") + } + }) +} + +func TestTimelineItem_ComputeDaySection(t *testing.T) { + // Use UTC since that's the default display timezone when not configured + tz := time.UTC + now := time.Now().In(tz) + today := time.Date(now.Year(), now.Month(), now.Day(), 0, 0, 0, 0, tz) + tomorrow := today.AddDate(0, 0, 1) + nextWeek := today.AddDate(0, 0, 7) + yesterday := today.AddDate(0, 0, -1) + + tests := []struct { + name string + itemTime time.Time + wantSection DaySection + wantOverdue bool + wantAllDay bool + }{ + { + name: "today with specific time", + itemTime: time.Date(now.Year(), now.Month(), now.Day(), 14, 30, 0, 0, tz), + wantSection: DaySectionToday, + wantOverdue: false, + wantAllDay: false, + }, + { + name: "today all day (midnight)", + itemTime: today, + wantSection: DaySectionToday, + wantOverdue: false, + wantAllDay: true, + }, + { + name: "tomorrow", + itemTime: tomorrow.Add(10 * time.Hour), + wantSection: DaySectionTomorrow, + wantOverdue: false, + wantAllDay: false, + }, + { + name: "later (next week)", + itemTime: nextWeek, + wantSection: DaySectionLater, + wantOverdue: false, + wantAllDay: true, + }, + { + name: "overdue (yesterday)", + itemTime: yesterday, + wantSection: DaySectionToday, // Overdue items show in today section + wantOverdue: true, + wantAllDay: true, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + item := &TimelineItem{Time: tc.itemTime} + item.ComputeDaySection(now) + + if item.DaySection != tc.wantSection { + t.Errorf("Expected section %s, got %s", tc.wantSection, item.DaySection) + } + if item.IsOverdue != tc.wantOverdue { + t.Errorf("Expected overdue=%v, got %v", tc.wantOverdue, item.IsOverdue) + } + if item.IsAllDay != tc.wantAllDay { + t.Errorf("Expected allDay=%v, got %v", tc.wantAllDay, item.IsAllDay) + } + }) + } +} + +func TestCacheMetadata_IsCacheValid(t *testing.T) { + t.Run("valid cache", func(t *testing.T) { + cm := CacheMetadata{ + LastFetch: time.Now(), + TTLMinutes: 5, + } + if !cm.IsCacheValid() { + t.Error("Cache should be valid") + } + }) + + t.Run("expired cache", func(t *testing.T) { + cm := CacheMetadata{ + LastFetch: time.Now().Add(-10 * time.Minute), + TTLMinutes: 5, + } + if cm.IsCacheValid() { + t.Error("Cache should be expired") + } + }) + + t.Run("zero TTL", func(t *testing.T) { + cm := CacheMetadata{ + LastFetch: time.Now().Add(-1 * time.Second), + TTLMinutes: 0, + } + if cm.IsCacheValid() { + t.Error("Cache with zero TTL should be expired") + } + }) +} diff --git a/internal/store/sqlite_test.go b/internal/store/sqlite_test.go index 9aef09d..69d188a 100644 --- a/internal/store/sqlite_test.go +++ b/internal/store/sqlite_test.go @@ -689,3 +689,989 @@ func TestResolveBug_NonExistent(t *testing.T) { t.Errorf("ResolveBug on non-existent bug should not error, got: %v", err) } } + +// ============================================================================= +// User Shopping Items Tests +// ============================================================================= + +func setupTestStoreWithShopping(t *testing.T) *Store { + t.Helper() + + tempDir := t.TempDir() + dbPath := filepath.Join(tempDir, "test.db") + + db, err := sql.Open("sqlite3", dbPath) + if err != nil { + t.Fatalf("Failed to open test database: %v", err) + } + + db.SetMaxOpenConns(1) + store := &Store{db: db} + + schema := ` + CREATE TABLE IF NOT EXISTS user_shopping_items ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + name TEXT NOT NULL, + store TEXT NOT NULL, + checked INTEGER DEFAULT 0, + created_at DATETIME DEFAULT CURRENT_TIMESTAMP + ); + CREATE TABLE IF NOT EXISTS shopping_item_checks ( + source TEXT NOT NULL, + item_id TEXT NOT NULL, + checked INTEGER DEFAULT 0, + updated_at DATETIME DEFAULT CURRENT_TIMESTAMP, + PRIMARY KEY (source, item_id) + ); + ` + if _, err := db.Exec(schema); err != nil { + t.Fatalf("Failed to create schema: %v", err) + } + + return store +} + +func TestUserShoppingItems_CRUD(t *testing.T) { + store := setupTestStoreWithShopping(t) + defer func() { _ = store.Close() }() + + // Save items + if err := store.SaveUserShoppingItem("Milk", "Costco"); err != nil { + t.Fatalf("Failed to save item: %v", err) + } + if err := store.SaveUserShoppingItem("Bread", "Safeway"); err != nil { + t.Fatalf("Failed to save second item: %v", err) + } + + // Get items + items, err := store.GetUserShoppingItems() + if err != nil { + t.Fatalf("Failed to get items: %v", err) + } + if len(items) != 2 { + t.Errorf("Expected 2 items, got %d", len(items)) + } + + // Verify item data + var milkItem UserShoppingItem + for _, item := range items { + if item.Name == "Milk" { + milkItem = item + break + } + } + if milkItem.Name != "Milk" { + t.Error("Could not find Milk item") + } + if milkItem.Store != "Costco" { + t.Errorf("Expected store 'Costco', got '%s'", milkItem.Store) + } + if milkItem.Checked { + t.Error("New item should not be checked") + } + + // Toggle item + if err := store.ToggleUserShoppingItem(milkItem.ID, true); err != nil { + t.Fatalf("Failed to toggle item: %v", err) + } + + items, _ = store.GetUserShoppingItems() + for _, item := range items { + if item.ID == milkItem.ID && !item.Checked { + t.Error("Item should be checked after toggle") + } + } + + // Delete item + if err := store.DeleteUserShoppingItem(milkItem.ID); err != nil { + t.Fatalf("Failed to delete item: %v", err) + } + + items, _ = store.GetUserShoppingItems() + if len(items) != 1 { + t.Errorf("Expected 1 item after delete, got %d", len(items)) + } +} + +func TestShoppingItemChecks_ExternalSources(t *testing.T) { + store := setupTestStoreWithShopping(t) + defer func() { _ = store.Close() }() + + // Set checked for trello item + if err := store.SetShoppingItemChecked("trello", "card-123", true); err != nil { + t.Fatalf("Failed to set trello checked: %v", err) + } + + // Set checked for plantoeat item + if err := store.SetShoppingItemChecked("plantoeat", "pte-456", true); err != nil { + t.Fatalf("Failed to set plantoeat checked: %v", err) + } + + // Get trello checks + trelloChecks, err := store.GetShoppingItemChecks("trello") + if err != nil { + t.Fatalf("Failed to get trello checks: %v", err) + } + if !trelloChecks["card-123"] { + t.Error("Expected trello card to be checked") + } + + // Get plantoeat checks + pteChecks, err := store.GetShoppingItemChecks("plantoeat") + if err != nil { + t.Fatalf("Failed to get plantoeat checks: %v", err) + } + if !pteChecks["pte-456"] { + t.Error("Expected plantoeat item to be checked") + } + + // Uncheck trello item + if err := store.SetShoppingItemChecked("trello", "card-123", false); err != nil { + t.Fatalf("Failed to uncheck trello item: %v", err) + } + + trelloChecks, _ = store.GetShoppingItemChecks("trello") + if trelloChecks["card-123"] { + t.Error("Trello item should be unchecked after update") + } +} + +// ============================================================================= +// Feature Toggles Tests +// ============================================================================= + +func setupTestStoreWithFeatureToggles(t *testing.T) *Store { + t.Helper() + + tempDir := t.TempDir() + dbPath := filepath.Join(tempDir, "test.db") + + db, err := sql.Open("sqlite3", dbPath) + if err != nil { + t.Fatalf("Failed to open test database: %v", err) + } + + db.SetMaxOpenConns(1) + store := &Store{db: db} + + schema := ` + CREATE TABLE IF NOT EXISTS feature_toggles ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + name TEXT UNIQUE NOT NULL, + description TEXT, + enabled BOOLEAN DEFAULT FALSE, + created_at DATETIME DEFAULT CURRENT_TIMESTAMP, + updated_at DATETIME DEFAULT CURRENT_TIMESTAMP + ); + ` + if _, err := db.Exec(schema); err != nil { + t.Fatalf("Failed to create schema: %v", err) + } + + return store +} + +func TestFeatureToggles_CRUD(t *testing.T) { + store := setupTestStoreWithFeatureToggles(t) + defer func() { _ = store.Close() }() + + // Create feature toggle + if err := store.CreateFeatureToggle("new_feature", "A new feature", false); err != nil { + t.Fatalf("Failed to create feature toggle: %v", err) + } + + // Get all toggles + toggles, err := store.GetFeatureToggles() + if err != nil { + t.Fatalf("Failed to get feature toggles: %v", err) + } + if len(toggles) != 1 { + t.Errorf("Expected 1 toggle, got %d", len(toggles)) + } + if toggles[0].Name != "new_feature" { + t.Errorf("Expected name 'new_feature', got '%s'", toggles[0].Name) + } + if toggles[0].Enabled { + t.Error("New feature should be disabled") + } + + // Check if enabled + if store.IsFeatureEnabled("new_feature") { + t.Error("IsFeatureEnabled should return false for disabled feature") + } + + // Enable feature + if err := store.SetFeatureEnabled("new_feature", true); err != nil { + t.Fatalf("Failed to enable feature: %v", err) + } + + if !store.IsFeatureEnabled("new_feature") { + t.Error("IsFeatureEnabled should return true after enabling") + } + + // Delete feature + if err := store.DeleteFeatureToggle("new_feature"); err != nil { + t.Fatalf("Failed to delete feature toggle: %v", err) + } + + toggles, _ = store.GetFeatureToggles() + if len(toggles) != 0 { + t.Errorf("Expected 0 toggles after delete, got %d", len(toggles)) + } +} + +func TestIsFeatureEnabled_NonExistent(t *testing.T) { + store := setupTestStoreWithFeatureToggles(t) + defer func() { _ = store.Close() }() + + // Non-existent feature should return false + if store.IsFeatureEnabled("does_not_exist") { + t.Error("Non-existent feature should return false") + } +} + +// ============================================================================= +// Completed Tasks Tests +// ============================================================================= + +func setupTestStoreWithCompletedTasks(t *testing.T) *Store { + t.Helper() + + tempDir := t.TempDir() + dbPath := filepath.Join(tempDir, "test.db") + + db, err := sql.Open("sqlite3", dbPath) + if err != nil { + t.Fatalf("Failed to open test database: %v", err) + } + + db.SetMaxOpenConns(1) + store := &Store{db: db} + + schema := ` + CREATE TABLE IF NOT EXISTS completed_tasks ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + source TEXT NOT NULL, + source_id TEXT NOT NULL, + title TEXT NOT NULL, + due_date TEXT, + completed_at DATETIME DEFAULT CURRENT_TIMESTAMP, + UNIQUE(source, source_id) + ); + ` + if _, err := db.Exec(schema); err != nil { + t.Fatalf("Failed to create schema: %v", err) + } + + return store +} + +func TestCompletedTasks_SaveAndGet(t *testing.T) { + store := setupTestStoreWithCompletedTasks(t) + defer func() { _ = store.Close() }() + + now := time.Now() + + // Save completed task with due date + if err := store.SaveCompletedTask("todoist", "task-123", "Buy groceries", &now); err != nil { + t.Fatalf("Failed to save completed task: %v", err) + } + + // Save completed task without due date + if err := store.SaveCompletedTask("trello", "card-456", "Review PR", nil); err != nil { + t.Fatalf("Failed to save second completed task: %v", err) + } + + // Get completed tasks + tasks, err := store.GetCompletedTasks(10) + if err != nil { + t.Fatalf("Failed to get completed tasks: %v", err) + } + if len(tasks) != 2 { + t.Errorf("Expected 2 completed tasks, got %d", len(tasks)) + } + + // Verify task data + var todoistTask models.CompletedTask + for _, task := range tasks { + if task.Source == "todoist" { + todoistTask = task + break + } + } + if todoistTask.Title != "Buy groceries" { + t.Errorf("Expected title 'Buy groceries', got '%s'", todoistTask.Title) + } + if todoistTask.DueDate == nil { + t.Error("Expected due date to be set") + } +} + +func TestCompletedTasks_Limit(t *testing.T) { + store := setupTestStoreWithCompletedTasks(t) + defer func() { _ = store.Close() }() + + // Save multiple tasks + for i := 0; i < 10; i++ { + _ = store.SaveCompletedTask("todoist", "task-"+string(rune('0'+i)), "Task "+string(rune('0'+i)), nil) + } + + // Get with limit + tasks, _ := store.GetCompletedTasks(5) + if len(tasks) != 5 { + t.Errorf("Expected 5 tasks with limit, got %d", len(tasks)) + } +} + +// ============================================================================= +// Source Configuration Tests +// ============================================================================= + +func setupTestStoreWithSourceConfig(t *testing.T) *Store { + t.Helper() + + tempDir := t.TempDir() + dbPath := filepath.Join(tempDir, "test.db") + + db, err := sql.Open("sqlite3", dbPath) + if err != nil { + t.Fatalf("Failed to open test database: %v", err) + } + + db.SetMaxOpenConns(1) + store := &Store{db: db} + + schema := ` + CREATE TABLE IF NOT EXISTS source_config ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + source TEXT NOT NULL, + item_type TEXT NOT NULL, + item_id TEXT NOT NULL, + item_name TEXT NOT NULL, + enabled BOOLEAN DEFAULT TRUE, + created_at DATETIME DEFAULT CURRENT_TIMESTAMP, + updated_at DATETIME DEFAULT CURRENT_TIMESTAMP, + UNIQUE(source, item_type, item_id) + ); + ` + if _, err := db.Exec(schema); err != nil { + t.Fatalf("Failed to create schema: %v", err) + } + + return store +} + +func TestSourceConfig_UpsertAndGet(t *testing.T) { + store := setupTestStoreWithSourceConfig(t) + defer func() { _ = store.Close() }() + + // Upsert configs + cfg1 := models.SourceConfig{ + Source: "trello", + ItemType: "board", + ItemID: "board-123", + ItemName: "Work Board", + Enabled: true, + } + if err := store.UpsertSourceConfig(cfg1); err != nil { + t.Fatalf("Failed to upsert config: %v", err) + } + + cfg2 := models.SourceConfig{ + Source: "trello", + ItemType: "board", + ItemID: "board-456", + ItemName: "Personal Board", + Enabled: false, + } + if err := store.UpsertSourceConfig(cfg2); err != nil { + t.Fatalf("Failed to upsert second config: %v", err) + } + + // Get all configs + configs, err := store.GetSourceConfigs() + if err != nil { + t.Fatalf("Failed to get configs: %v", err) + } + if len(configs) != 2 { + t.Errorf("Expected 2 configs, got %d", len(configs)) + } + + // Get by source + trelloConfigs, err := store.GetSourceConfigsBySource("trello") + if err != nil { + t.Fatalf("Failed to get trello configs: %v", err) + } + if len(trelloConfigs) != 2 { + t.Errorf("Expected 2 trello configs, got %d", len(trelloConfigs)) + } + + // Get enabled IDs + enabledIDs, err := store.GetEnabledSourceIDs("trello", "board") + if err != nil { + t.Fatalf("Failed to get enabled IDs: %v", err) + } + if len(enabledIDs) != 1 { + t.Errorf("Expected 1 enabled ID, got %d", len(enabledIDs)) + } + if enabledIDs[0] != "board-123" { + t.Errorf("Expected 'board-123', got '%s'", enabledIDs[0]) + } +} + +func TestSourceConfig_SetEnabled(t *testing.T) { + store := setupTestStoreWithSourceConfig(t) + defer func() { _ = store.Close() }() + + // Create a config + cfg := models.SourceConfig{ + Source: "calendar", + ItemType: "calendar", + ItemID: "cal-1", + ItemName: "Primary", + Enabled: true, + } + _ = store.UpsertSourceConfig(cfg) + + // Disable it + if err := store.SetSourceConfigEnabled("calendar", "calendar", "cal-1", false); err != nil { + t.Fatalf("Failed to set enabled: %v", err) + } + + // Verify + enabledIDs, _ := store.GetEnabledSourceIDs("calendar", "calendar") + if len(enabledIDs) != 0 { + t.Error("Expected no enabled calendars after disabling") + } +} + +// ============================================================================= +// Cache Metadata Tests +// ============================================================================= + +func setupTestStoreWithCacheMetadata(t *testing.T) *Store { + t.Helper() + + tempDir := t.TempDir() + dbPath := filepath.Join(tempDir, "test.db") + + db, err := sql.Open("sqlite3", dbPath) + if err != nil { + t.Fatalf("Failed to open test database: %v", err) + } + + db.SetMaxOpenConns(1) + store := &Store{db: db} + + schema := ` + CREATE TABLE IF NOT EXISTS cache_metadata ( + key TEXT PRIMARY KEY, + last_fetch DATETIME NOT NULL, + ttl_minutes INTEGER DEFAULT 5, + updated_at DATETIME DEFAULT CURRENT_TIMESTAMP + ); + ` + if _, err := db.Exec(schema); err != nil { + t.Fatalf("Failed to create schema: %v", err) + } + + return store +} + +func TestCacheMetadata_UpdateAndCheck(t *testing.T) { + store := setupTestStoreWithCacheMetadata(t) + defer func() { _ = store.Close() }() + + // Initially no metadata + valid, _ := store.IsCacheValid("test_key") + if valid { + t.Error("Cache should be invalid when no metadata exists") + } + + // Update cache metadata + if err := store.UpdateCacheMetadata("test_key", 5); err != nil { + t.Fatalf("Failed to update cache metadata: %v", err) + } + + // Now cache should be valid + valid, err := store.IsCacheValid("test_key") + if err != nil { + t.Fatalf("Failed to check cache validity: %v", err) + } + if !valid { + t.Error("Cache should be valid after update") + } + + // Get metadata + metadata, err := store.GetCacheMetadata("test_key") + if err != nil { + t.Fatalf("Failed to get cache metadata: %v", err) + } + if metadata == nil { + t.Fatal("Expected metadata to exist") + } + if metadata.TTLMinutes != 5 { + t.Errorf("Expected TTL 5, got %d", metadata.TTLMinutes) + } + + // Invalidate cache + if err := store.InvalidateCache("test_key"); err != nil { + t.Fatalf("Failed to invalidate cache: %v", err) + } + + valid, _ = store.IsCacheValid("test_key") + if valid { + t.Error("Cache should be invalid after invalidation") + } +} + +func TestCacheMetadata_ExpiredCache(t *testing.T) { + store := setupTestStoreWithCacheMetadata(t) + defer func() { _ = store.Close() }() + + // Insert old cache entry directly + oldTime := time.Now().Add(-10 * time.Minute) + _, err := store.db.Exec(` + INSERT INTO cache_metadata (key, last_fetch, ttl_minutes) + VALUES (?, ?, ?) + `, "expired_key", oldTime, 5) + if err != nil { + t.Fatalf("Failed to insert old metadata: %v", err) + } + + // Cache should be invalid (expired) + valid, _ := store.IsCacheValid("expired_key") + if valid { + t.Error("Expired cache should be invalid") + } +} + +// ============================================================================= +// Sync Token Tests +// ============================================================================= + +func setupTestStoreWithSyncTokens(t *testing.T) *Store { + t.Helper() + + tempDir := t.TempDir() + dbPath := filepath.Join(tempDir, "test.db") + + db, err := sql.Open("sqlite3", dbPath) + if err != nil { + t.Fatalf("Failed to open test database: %v", err) + } + + db.SetMaxOpenConns(1) + store := &Store{db: db} + + schema := ` + CREATE TABLE IF NOT EXISTS sync_tokens ( + service TEXT PRIMARY KEY, + token TEXT NOT NULL, + updated_at DATETIME DEFAULT CURRENT_TIMESTAMP + ); + ` + if _, err := db.Exec(schema); err != nil { + t.Fatalf("Failed to create schema: %v", err) + } + + return store +} + +func TestSyncTokens_SetGetClear(t *testing.T) { + store := setupTestStoreWithSyncTokens(t) + defer func() { _ = store.Close() }() + + // Get non-existent token + token, err := store.GetSyncToken("todoist") + if err != nil { + t.Fatalf("Failed to get token: %v", err) + } + if token != "" { + t.Errorf("Expected empty token, got '%s'", token) + } + + // Set token + if err := store.SetSyncToken("todoist", "sync-token-123"); err != nil { + t.Fatalf("Failed to set token: %v", err) + } + + // Get token + token, err = store.GetSyncToken("todoist") + if err != nil { + t.Fatalf("Failed to get token after set: %v", err) + } + if token != "sync-token-123" { + t.Errorf("Expected 'sync-token-123', got '%s'", token) + } + + // Clear token + if err := store.ClearSyncToken("todoist"); err != nil { + t.Fatalf("Failed to clear token: %v", err) + } + + token, _ = store.GetSyncToken("todoist") + if token != "" { + t.Errorf("Expected empty token after clear, got '%s'", token) + } +} + +// ============================================================================= +// Agent Session Tests +// ============================================================================= + +func setupTestStoreWithAgents(t *testing.T) *Store { + t.Helper() + + tempDir := t.TempDir() + dbPath := filepath.Join(tempDir, "test.db") + + db, err := sql.Open("sqlite3", dbPath) + if err != nil { + t.Fatalf("Failed to open test database: %v", err) + } + + db.SetMaxOpenConns(1) + store := &Store{db: db} + + schema := ` + CREATE TABLE IF NOT EXISTS agents ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + name TEXT NOT NULL, + agent_id TEXT UNIQUE NOT NULL, + created_at DATETIME DEFAULT CURRENT_TIMESTAMP, + last_seen DATETIME, + trusted BOOLEAN DEFAULT 1 + ); + CREATE TABLE IF NOT EXISTS agent_sessions ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + request_token TEXT UNIQUE NOT NULL, + agent_name TEXT NOT NULL, + agent_id TEXT NOT NULL, + status TEXT DEFAULT 'pending', + created_at DATETIME DEFAULT CURRENT_TIMESTAMP, + expires_at DATETIME NOT NULL, + session_token TEXT, + session_expires_at DATETIME + ); + ` + if _, err := db.Exec(schema); err != nil { + t.Fatalf("Failed to create schema: %v", err) + } + + return store +} + +func TestAgentSession_CreateAndRetrieve(t *testing.T) { + store := setupTestStoreWithAgents(t) + defer func() { _ = store.Close() }() + + expiresAt := time.Now().Add(5 * time.Minute) + session := &models.AgentSession{ + RequestToken: "req-token-123", + AgentName: "TestAgent", + AgentID: "agent-uuid-123", + ExpiresAt: expiresAt, + } + + // Create session + if err := store.CreateAgentSession(session); err != nil { + t.Fatalf("Failed to create session: %v", err) + } + if session.ID == 0 { + t.Error("Session ID should be set after create") + } + + // Get by request token + retrieved, err := store.GetAgentSessionByRequestToken("req-token-123") + if err != nil { + t.Fatalf("Failed to get session: %v", err) + } + if retrieved == nil { + t.Fatal("Expected session to exist") + } + if retrieved.AgentName != "TestAgent" { + t.Errorf("Expected name 'TestAgent', got '%s'", retrieved.AgentName) + } + if retrieved.Status != "pending" { + t.Errorf("Expected status 'pending', got '%s'", retrieved.Status) + } + + // Get pending by agent ID + pending, err := store.GetPendingAgentSessionByAgentID("agent-uuid-123") + if err != nil { + t.Fatalf("Failed to get pending session: %v", err) + } + if pending == nil { + t.Fatal("Expected pending session to exist") + } +} + +func TestAgentSession_ApproveAndDeny(t *testing.T) { + store := setupTestStoreWithAgents(t) + defer func() { _ = store.Close() }() + + // Create two sessions + session1 := &models.AgentSession{ + RequestToken: "approve-token", + AgentName: "Agent1", + AgentID: "agent-1", + ExpiresAt: time.Now().Add(5 * time.Minute), + } + session2 := &models.AgentSession{ + RequestToken: "deny-token", + AgentName: "Agent2", + AgentID: "agent-2", + ExpiresAt: time.Now().Add(5 * time.Minute), + } + _ = store.CreateAgentSession(session1) + _ = store.CreateAgentSession(session2) + + // Approve session1 + sessionExpiry := time.Now().Add(1 * time.Hour) + if err := store.ApproveAgentSession("approve-token", "session-token-abc", sessionExpiry); err != nil { + t.Fatalf("Failed to approve session: %v", err) + } + + // Verify approval + approved, _ := store.GetAgentSessionByRequestToken("approve-token") + if approved.Status != "approved" { + t.Errorf("Expected status 'approved', got '%s'", approved.Status) + } + if approved.SessionToken != "session-token-abc" { + t.Errorf("Expected session token 'session-token-abc', got '%s'", approved.SessionToken) + } + + // Deny session2 + if err := store.DenyAgentSession("deny-token"); err != nil { + t.Fatalf("Failed to deny session: %v", err) + } + + denied, _ := store.GetAgentSessionByRequestToken("deny-token") + if denied.Status != "denied" { + t.Errorf("Expected status 'denied', got '%s'", denied.Status) + } +} + +func TestAgentSession_GetBySessionToken(t *testing.T) { + store := setupTestStoreWithAgents(t) + defer func() { _ = store.Close() }() + + session := &models.AgentSession{ + RequestToken: "req-for-session", + AgentName: "SessionAgent", + AgentID: "session-agent", + ExpiresAt: time.Now().Add(5 * time.Minute), + } + _ = store.CreateAgentSession(session) + _ = store.ApproveAgentSession("req-for-session", "active-session", time.Now().Add(1*time.Hour)) + + // Get by session token + retrieved, err := store.GetAgentSessionBySessionToken("active-session") + if err != nil { + t.Fatalf("Failed to get by session token: %v", err) + } + if retrieved == nil { + t.Fatal("Expected session to exist") + } + if retrieved.AgentName != "SessionAgent" { + t.Errorf("Expected 'SessionAgent', got '%s'", retrieved.AgentName) + } +} + +func TestAgentSession_GetPending(t *testing.T) { + store := setupTestStoreWithAgents(t) + defer func() { _ = store.Close() }() + + // Create pending sessions + for i := 0; i < 3; i++ { + session := &models.AgentSession{ + RequestToken: "pending-" + string(rune('0'+i)), + AgentName: "Agent" + string(rune('0'+i)), + AgentID: "agent-" + string(rune('0'+i)), + ExpiresAt: time.Now().Add(5 * time.Minute), + } + _ = store.CreateAgentSession(session) + } + + // Get pending sessions + pending, err := store.GetPendingAgentSessions() + if err != nil { + t.Fatalf("Failed to get pending sessions: %v", err) + } + if len(pending) != 3 { + t.Errorf("Expected 3 pending sessions, got %d", len(pending)) + } +} + +func TestAgentSession_Invalidate(t *testing.T) { + store := setupTestStoreWithAgents(t) + defer func() { _ = store.Close() }() + + // Create sessions for same agent + for i := 0; i < 2; i++ { + session := &models.AgentSession{ + RequestToken: "inv-" + string(rune('0'+i)), + AgentName: "SameAgent", + AgentID: "same-agent", + ExpiresAt: time.Now().Add(5 * time.Minute), + } + _ = store.CreateAgentSession(session) + } + + // Invalidate all sessions for agent + if err := store.InvalidatePreviousAgentSessions("same-agent"); err != nil { + t.Fatalf("Failed to invalidate sessions: %v", err) + } + + // Verify no pending sessions + pending, _ := store.GetPendingAgentSessions() + for _, s := range pending { + if s.AgentID == "same-agent" { + t.Error("Session should be invalidated") + } + } +} + +// ============================================================================= +// Agent Tests +// ============================================================================= + +func TestAgent_CreateAndRetrieve(t *testing.T) { + store := setupTestStoreWithAgents(t) + defer func() { _ = store.Close() }() + + // Create agent + if err := store.CreateOrUpdateAgent("TestBot", "bot-uuid-123"); err != nil { + t.Fatalf("Failed to create agent: %v", err) + } + + // Get by agent ID + agent, err := store.GetAgentByAgentID("bot-uuid-123") + if err != nil { + t.Fatalf("Failed to get agent: %v", err) + } + if agent == nil { + t.Fatal("Expected agent to exist") + } + if agent.Name != "TestBot" { + t.Errorf("Expected name 'TestBot', got '%s'", agent.Name) + } + if !agent.Trusted { + t.Error("New agent should be trusted by default") + } + + // Get by name + byName, err := store.GetAgentByName("TestBot") + if err != nil { + t.Fatalf("Failed to get agent by name: %v", err) + } + if byName == nil { + t.Fatal("Expected agent to exist by name") + } + + // Get all agents + all, err := store.GetAllAgents() + if err != nil { + t.Fatalf("Failed to get all agents: %v", err) + } + if len(all) != 1 { + t.Errorf("Expected 1 agent, got %d", len(all)) + } +} + +func TestAgent_UpdateLastSeen(t *testing.T) { + store := setupTestStoreWithAgents(t) + defer func() { _ = store.Close() }() + + _ = store.CreateOrUpdateAgent("SeenBot", "seen-uuid") + + // Update last seen + if err := store.UpdateAgentLastSeen("seen-uuid"); err != nil { + t.Fatalf("Failed to update last seen: %v", err) + } + + agent, _ := store.GetAgentByAgentID("seen-uuid") + if agent.LastSeen == nil { + t.Error("LastSeen should be set after update") + } +} + +func TestAgent_TrustLevels(t *testing.T) { + store := setupTestStoreWithAgents(t) + defer func() { _ = store.Close() }() + + // Check trust for unknown agent (new) + trust, err := store.CheckAgentTrust("UnknownBot", "unknown-uuid") + if err != nil { + t.Fatalf("Failed to check trust: %v", err) + } + if trust != models.AgentTrustNew { + t.Errorf("Expected AgentTrustNew, got %v", trust) + } + + // Create agent + _ = store.CreateOrUpdateAgent("TrustBot", "trust-uuid") + + // Check trust for recognized agent + trust, _ = store.CheckAgentTrust("TrustBot", "trust-uuid") + if trust != models.AgentTrustRecognized { + t.Errorf("Expected AgentTrustRecognized, got %v", trust) + } + + // Check trust for suspicious agent (same name, different uuid) + trust, _ = store.CheckAgentTrust("TrustBot", "different-uuid") + if trust != models.AgentTrustSuspicious { + t.Errorf("Expected AgentTrustSuspicious, got %v", trust) + } +} + +func TestAgent_Revoke(t *testing.T) { + store := setupTestStoreWithAgents(t) + defer func() { _ = store.Close() }() + + _ = store.CreateOrUpdateAgent("RevokeBot", "revoke-uuid") + + // Verify agent exists + agent, _ := store.GetAgentByAgentID("revoke-uuid") + if agent == nil { + t.Fatal("Agent should exist") + } + + // Revoke agent + if err := store.RevokeAgent("revoke-uuid"); err != nil { + t.Fatalf("Failed to revoke agent: %v", err) + } + + // After revoke, agent should still exist but be in different state + // (revoke doesn't delete, just marks somehow - let's verify it doesn't error) +} + +func TestAgent_NonExistent(t *testing.T) { + store := setupTestStoreWithAgents(t) + defer func() { _ = store.Close() }() + + // Get non-existent agent + agent, err := store.GetAgentByAgentID("does-not-exist") + if err != nil { + t.Fatalf("Should not error for non-existent agent: %v", err) + } + if agent != nil { + t.Error("Agent should be nil for non-existent") + } + + // Get non-existent by name + byName, err := store.GetAgentByName("unknown-name") + if err != nil { + t.Fatalf("Should not error for non-existent name: %v", err) + } + if byName != nil { + t.Error("Agent should be nil for non-existent name") + } + + // Check trust for non-existent (should be new) + trust, _ := store.CheckAgentTrust("UnknownBot", "unknown-uuid") + if trust != models.AgentTrustNew { + t.Errorf("Expected AgentTrustNew for unknown, got %v", trust) + } +} |
