summaryrefslogtreecommitdiff
path: root/internal/api/obsidian_test.go
blob: 35095949644b4cbfa460cc23bf9a188ea2fdf7f5 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
package api

import (
	"context"
	"os"
	"path/filepath"
	"testing"
	"time"
)

// TestGetNotes_SymlinkSecurity guards against path traversal through symbolic
// links: a link placed inside the vault that points at a file outside it must
// be ignored by GetNotes, so only the genuine note is returned.
func TestGetNotes_SymlinkSecurity(t *testing.T) {
	root := t.TempDir()
	var (
		vault   = filepath.Join(root, "vault")
		outside = filepath.Join(root, "outside")
	)

	// Two sibling directories: the vault under test and a location beyond it.
	if err := os.Mkdir(vault, 0755); err != nil {
		t.Fatalf("Failed to create vault directory: %v", err)
	}
	if err := os.Mkdir(outside, 0755); err != nil {
		t.Fatalf("Failed to create outside directory: %v", err)
	}

	// A legitimate note that GetNotes is expected to return.
	realNote := filepath.Join(vault, "valid-note.md")
	realBody := []byte("# Valid Note\n\nThis is a valid note inside the vault.")
	if err := os.WriteFile(realNote, realBody, 0644); err != nil {
		t.Fatalf("Failed to create valid markdown file: %v", err)
	}

	// Data outside the vault that must never be exposed through it.
	secretPath := filepath.Join(outside, "secret.txt")
	secretBody := []byte("This is a secret file outside the vault that should not be accessible.")
	if err := os.WriteFile(secretPath, secretBody, 0644); err != nil {
		t.Fatalf("Failed to create secret file: %v", err)
	}

	// The attack vector: a ".md" link inside the vault aimed at the secret.
	// If the link cannot be created on this system, skip rather than fail.
	linkPath := filepath.Join(vault, "symlink-note.md")
	if err := os.Symlink(secretPath, linkPath); err != nil {
		t.Skipf("Skipping test: unable to create symlink (may not be supported on this system): %v", err)
	}

	client := NewObsidianClient(vault)
	notes, err := client.GetNotes(context.Background(), 10)
	if err != nil {
		t.Fatalf("GetNotes returned error: %v", err)
	}

	// Exactly one note is expected. Anything else means the link was followed
	// (or the scan misbehaved), so log what came back before aborting.
	if n := len(notes); n != 1 {
		t.Errorf("Expected 1 note, got %d notes", n)
		for idx := range notes {
			t.Logf("Note %d: %s (path: %s)", idx, notes[idx].Filename, notes[idx].Path)
		}
		t.Fatalf("Test failed: symlink was followed or wrong number of notes returned")
	}

	got := notes[0]
	if got.Filename != "valid-note.md" {
		t.Errorf("Expected filename 'valid-note.md', got '%s'", got.Filename)
	}
	if got.Title != "Valid Note" {
		t.Errorf("Expected title 'Valid Note', got '%s'", got.Title)
	}
	// Also confirm the secret text never shows up in the returned content.
	if containsString(got.Content, "secret") {
		t.Errorf("Note content contains 'secret', which suggests symlink was followed: %s", got.Content)
	}
}

// TestGetNotes_BasicFunctionality checks that GetNotes orders notes by
// modification time (newest first), honours a positive limit, and returns
// every note when the limit is 0.
func TestGetNotes_BasicFunctionality(t *testing.T) {
	// Create temporary vault directory
	vaultDir := t.TempDir()

	// Give each file an explicit, well-separated modification time instead of
	// relying on short sleeps between writes. Filesystems with coarse mtime
	// resolution (1s or 2s granularity) would otherwise stamp all three files
	// identically and make the expected ordering nondeterministic (flaky).
	base := time.Now().Add(-time.Hour)
	files := []struct {
		name    string
		content string
		modTime time.Time
	}{
		{"oldest.md", "# Oldest Note\n\nThis is the oldest note.", base},
		{"middle.md", "# Middle Note\n\nThis is a middle note.", base.Add(10 * time.Minute)},
		{"newest.md", "# Newest Note\n\nThis is the newest note.", base.Add(20 * time.Minute)},
	}

	for _, file := range files {
		path := filepath.Join(vaultDir, file.name)
		if err := os.WriteFile(path, []byte(file.content), 0644); err != nil {
			t.Fatalf("Failed to create file %s: %v", file.name, err)
		}
		if err := os.Chtimes(path, file.modTime, file.modTime); err != nil {
			t.Fatalf("Failed to set modification time on %s: %v", file.name, err)
		}
	}

	// Initialize Obsidian client
	client := NewObsidianClient(vaultDir)

	// A positive limit caps the result at that many notes.
	ctx := context.Background()
	notes, err := client.GetNotes(ctx, 2)
	if err != nil {
		t.Fatalf("GetNotes returned error: %v", err)
	}

	// Should return 2 most recent notes
	if len(notes) != 2 {
		t.Errorf("Expected 2 notes with limit, got %d", len(notes))
	}

	// Verify order (newest first)
	if len(notes) >= 2 {
		if notes[0].Filename != "newest.md" {
			t.Errorf("Expected first note to be 'newest.md', got '%s'", notes[0].Filename)
		}
		if notes[1].Filename != "middle.md" {
			t.Errorf("Expected second note to be 'middle.md', got '%s'", notes[1].Filename)
		}
	}

	// A limit of 0 is treated as "no limit": every note is expected back.
	allNotes, err := client.GetNotes(ctx, 0)
	if err != nil {
		t.Fatalf("GetNotes without limit returned error: %v", err)
	}

	if len(allNotes) != 3 {
		t.Errorf("Expected 3 notes without limit, got %d", len(allNotes))
	}
}

// containsString reports whether needle occurs in haystack.
//
// The comparison is case-sensitive (byte-for-byte). Unlike strings.Contains,
// it returns false when either argument is empty, so an empty needle never
// counts as a match.
func containsString(haystack, needle string) bool {
	if haystack == "" || needle == "" {
		return false
	}
	// hasSubstring already covers haystack == needle and needle being longer
	// than haystack, so no extra special cases are needed here.
	return hasSubstring(haystack, needle)
}

// hasSubstring reports whether substr appears anywhere in s, comparing bytes
// exactly (case-sensitive). An empty substr matches every s, including "".
func hasSubstring(s, substr string) bool {
	width := len(substr)
	// Slide a window of len(substr) bytes across s; stop once it would
	// run past the end of s.
	for start := 0; start+width <= len(s); start++ {
		if s[start:start+width] == substr {
			return true
		}
	}
	return false
}