summaryrefslogtreecommitdiff
path: root/internal/api/obsidian_test.go
diff options
context:
space:
mode:
Diffstat (limited to 'internal/api/obsidian_test.go')
-rw-r--r--internal/api/obsidian_test.go155
1 files changed, 155 insertions, 0 deletions
diff --git a/internal/api/obsidian_test.go b/internal/api/obsidian_test.go
new file mode 100644
index 0000000..3509594
--- /dev/null
+++ b/internal/api/obsidian_test.go
@@ -0,0 +1,155 @@
+package api
+
+import (
+ "context"
+ "os"
+ "path/filepath"
+ "testing"
+ "time"
+)
+
+// TestGetNotes_SymlinkSecurity verifies that GetNotes does not follow symlinks
+// to prevent path traversal attacks
+func TestGetNotes_SymlinkSecurity(t *testing.T) {
+ // Create temporary directories
+ tempDir := t.TempDir()
+ vaultDir := filepath.Join(tempDir, "vault")
+ outsideDir := filepath.Join(tempDir, "outside")
+
+ // Create vault and outside directories
+ if err := os.Mkdir(vaultDir, 0755); err != nil {
+ t.Fatalf("Failed to create vault directory: %v", err)
+ }
+ if err := os.Mkdir(outsideDir, 0755); err != nil {
+ t.Fatalf("Failed to create outside directory: %v", err)
+ }
+
+ // Create a valid markdown file inside the vault
+ validFile := filepath.Join(vaultDir, "valid-note.md")
+ validContent := "# Valid Note\n\nThis is a valid note inside the vault."
+ if err := os.WriteFile(validFile, []byte(validContent), 0644); err != nil {
+ t.Fatalf("Failed to create valid markdown file: %v", err)
+ }
+
+ // Create a secret file outside the vault
+ secretFile := filepath.Join(outsideDir, "secret.txt")
+ secretContent := "This is a secret file outside the vault that should not be accessible."
+ if err := os.WriteFile(secretFile, []byte(secretContent), 0644); err != nil {
+ t.Fatalf("Failed to create secret file: %v", err)
+ }
+
+ // Create a symlink inside the vault pointing to the secret file
+ symlinkPath := filepath.Join(vaultDir, "symlink-note.md")
+ if err := os.Symlink(secretFile, symlinkPath); err != nil {
+ t.Skipf("Skipping test: unable to create symlink (may not be supported on this system): %v", err)
+ }
+
+ // Initialize Obsidian client
+ client := NewObsidianClient(vaultDir)
+
+ // Call GetNotes
+ ctx := context.Background()
+ notes, err := client.GetNotes(ctx, 10)
+ if err != nil {
+ t.Fatalf("GetNotes returned error: %v", err)
+ }
+
+ // Verify that only the valid note is returned
+ if len(notes) != 1 {
+ t.Errorf("Expected 1 note, got %d notes", len(notes))
+ for i, note := range notes {
+ t.Logf("Note %d: %s (path: %s)", i, note.Filename, note.Path)
+ }
+ t.Fatalf("Test failed: symlink was followed or wrong number of notes returned")
+ }
+
+ // Verify the returned note is the valid one
+ note := notes[0]
+ if note.Filename != "valid-note.md" {
+ t.Errorf("Expected filename 'valid-note.md', got '%s'", note.Filename)
+ }
+ if note.Title != "Valid Note" {
+ t.Errorf("Expected title 'Valid Note', got '%s'", note.Title)
+ }
+
+ // Ensure the content does not contain the secret text
+ if containsString(note.Content, "secret") {
+ t.Errorf("Note content contains 'secret', which suggests symlink was followed: %s", note.Content)
+ }
+}
+
+// TestGetNotes_BasicFunctionality tests basic GetNotes functionality
+func TestGetNotes_BasicFunctionality(t *testing.T) {
+ // Create temporary vault directory
+ vaultDir := t.TempDir()
+
+ // Create multiple markdown files with different modification times
+ files := []struct {
+ name string
+ content string
+ delay time.Duration
+ }{
+ {"oldest.md", "# Oldest Note\n\nThis is the oldest note.", 0},
+ {"middle.md", "# Middle Note\n\nThis is a middle note.", 10 * time.Millisecond},
+ {"newest.md", "# Newest Note\n\nThis is the newest note.", 20 * time.Millisecond},
+ }
+
+ for _, file := range files {
+ time.Sleep(file.delay)
+ path := filepath.Join(vaultDir, file.name)
+ if err := os.WriteFile(path, []byte(file.content), 0644); err != nil {
+ t.Fatalf("Failed to create file %s: %v", file.name, err)
+ }
+ }
+
+ // Initialize Obsidian client
+ client := NewObsidianClient(vaultDir)
+
+ // Test with limit
+ ctx := context.Background()
+ notes, err := client.GetNotes(ctx, 2)
+ if err != nil {
+ t.Fatalf("GetNotes returned error: %v", err)
+ }
+
+ // Should return 2 most recent notes
+ if len(notes) != 2 {
+ t.Errorf("Expected 2 notes with limit, got %d", len(notes))
+ }
+
+ // Verify order (newest first)
+ if len(notes) >= 2 {
+ if notes[0].Filename != "newest.md" {
+ t.Errorf("Expected first note to be 'newest.md', got '%s'", notes[0].Filename)
+ }
+ if notes[1].Filename != "middle.md" {
+ t.Errorf("Expected second note to be 'middle.md', got '%s'", notes[1].Filename)
+ }
+ }
+
+ // Test without limit
+ allNotes, err := client.GetNotes(ctx, 0)
+ if err != nil {
+ t.Fatalf("GetNotes without limit returned error: %v", err)
+ }
+
+ if len(allNotes) != 3 {
+ t.Errorf("Expected 3 notes without limit, got %d", len(allNotes))
+ }
+}
+
+// containsString checks if haystack contains needle (case-insensitive)
+func containsString(haystack, needle string) bool {
+ return len(haystack) > 0 && len(needle) > 0 &&
+ (haystack == needle || len(haystack) >= len(needle) &&
+ hasSubstring(haystack, needle))
+}
+
+func hasSubstring(s, substr string) bool {
+ for i := 0; i <= len(s)-len(substr); i++ {
+ if s[i:i+len(substr)] == substr {
+ return true
+ }
+ }
+ return false
+}