package api

import (
	"context"
	"os"
	"path/filepath"
	"strings"
	"testing"
	"time"
)

// TestGetNotes_SymlinkSecurity verifies that GetNotes does not follow symlinks,
// so a link placed inside the vault cannot expose files that live outside it
// (path traversal).
func TestGetNotes_SymlinkSecurity(t *testing.T) {
	tempDir := t.TempDir()
	vaultDir := filepath.Join(tempDir, "vault")
	outsideDir := filepath.Join(tempDir, "outside")

	if err := os.Mkdir(vaultDir, 0755); err != nil {
		t.Fatalf("Failed to create vault directory: %v", err)
	}
	if err := os.Mkdir(outsideDir, 0755); err != nil {
		t.Fatalf("Failed to create outside directory: %v", err)
	}

	// A legitimate markdown note inside the vault.
	validFile := filepath.Join(vaultDir, "valid-note.md")
	validContent := "# Valid Note\n\nThis is a valid note inside the vault."
	if err := os.WriteFile(validFile, []byte(validContent), 0644); err != nil {
		t.Fatalf("Failed to create valid markdown file: %v", err)
	}

	// A file outside the vault whose contents must never be reachable through it.
	secretFile := filepath.Join(outsideDir, "secret.txt")
	secretContent := "This is a secret file outside the vault that should not be accessible."
	if err := os.WriteFile(secretFile, []byte(secretContent), 0644); err != nil {
		t.Fatalf("Failed to create secret file: %v", err)
	}

	// A symlink inside the vault with a .md name, pointing at the outside file.
	// If the system cannot create symlinks, skip rather than fail.
	symlinkPath := filepath.Join(vaultDir, "symlink-note.md")
	if err := os.Symlink(secretFile, symlinkPath); err != nil {
		t.Skipf("Skipping test: unable to create symlink (may not be supported on this system): %v", err)
	}

	client := NewObsidianClient(vaultDir)

	notes, err := client.GetNotes(context.Background(), 10)
	if err != nil {
		t.Fatalf("GetNotes returned error: %v", err)
	}

	// Only the regular file may be returned; the symlink must be skipped.
	if len(notes) != 1 {
		t.Errorf("Expected 1 note, got %d notes", len(notes))
		for i, note := range notes {
			t.Logf("Note %d: %s (path: %s)", i, note.Filename, note.Path)
		}
		t.Fatalf("Test failed: symlink was followed or wrong number of notes returned")
	}

	note := notes[0]
	if note.Filename != "valid-note.md" {
		t.Errorf("Expected filename 'valid-note.md', got '%s'", note.Filename)
	}
	if note.Title != "Valid Note" {
		t.Errorf("Expected title 'Valid Note', got '%s'", note.Title)
	}
	// Defence in depth: even when the count is right, the outside file's text
	// must not have leaked into the returned note.
	if containsString(note.Content, "secret") {
		t.Errorf("Note content contains 'secret', which suggests symlink was followed: %s", note.Content)
	}
}

// TestGetNotes_BasicFunctionality checks that GetNotes returns notes newest
// first by modification time, honours a positive limit, and returns every
// note when the limit is 0.
func TestGetNotes_BasicFunctionality(t *testing.T) {
	vaultDir := t.TempDir()

	// Modification times are set explicitly with os.Chtimes instead of sleeping
	// between writes. Many filesystems store mtimes with coarse (up to 1-2 s)
	// granularity, so millisecond sleeps can leave files with identical mtimes
	// and make the ordering assertions below flaky. Files are still written in
	// oldest-to-newest order, so creation order agrees with the mtimes.
	base := time.Now().Add(-time.Hour)
	files := []struct {
		name    string
		content string
		modTime time.Time
	}{
		{"oldest.md", "# Oldest Note\n\nThis is the oldest note.", base},
		{"middle.md", "# Middle Note\n\nThis is a middle note.", base.Add(1 * time.Minute)},
		{"newest.md", "# Newest Note\n\nThis is the newest note.", base.Add(2 * time.Minute)},
	}
	for _, file := range files {
		path := filepath.Join(vaultDir, file.name)
		if err := os.WriteFile(path, []byte(file.content), 0644); err != nil {
			t.Fatalf("Failed to create file %s: %v", file.name, err)
		}
		if err := os.Chtimes(path, file.modTime, file.modTime); err != nil {
			t.Fatalf("Failed to set modification time on %s: %v", file.name, err)
		}
	}

	client := NewObsidianClient(vaultDir)
	ctx := context.Background()

	// With a limit, only the N most recent notes come back.
	notes, err := client.GetNotes(ctx, 2)
	if err != nil {
		t.Fatalf("GetNotes returned error: %v", err)
	}
	if len(notes) != 2 {
		t.Errorf("Expected 2 notes with limit, got %d", len(notes))
	}

	// Newest first. Guarded so a wrong count above does not panic here and the
	// no-limit check below still runs.
	if len(notes) >= 2 {
		if notes[0].Filename != "newest.md" {
			t.Errorf("Expected first note to be 'newest.md', got '%s'", notes[0].Filename)
		}
		if notes[1].Filename != "middle.md" {
			t.Errorf("Expected second note to be 'middle.md', got '%s'", notes[1].Filename)
		}
	}

	// A limit of 0 returns every note.
	allNotes, err := client.GetNotes(ctx, 0)
	if err != nil {
		t.Fatalf("GetNotes without limit returned error: %v", err)
	}
	if len(allNotes) != 3 {
		t.Errorf("Expected 3 notes without limit, got %d", len(allNotes))
	}
}

// containsString reports whether needle occurs in haystack. The comparison is
// case-sensitive. Unlike strings.Contains, an empty needle never matches (and
// therefore neither does anything against an empty haystack).
func containsString(haystack, needle string) bool {
	return needle != "" && strings.Contains(haystack, needle)
}

// hasSubstring reports whether substr occurs in s (case-sensitive; an empty
// substr always matches). It is a thin wrapper over strings.Contains, kept in
// case other test files in this package still call it.
func hasSubstring(s, substr string) bool {
	return strings.Contains(s, substr)
}