mirror of https://github.com/jlelse/GoBlog
parent
33e9d53a93
commit
d48f4f556a
9 changed files with 225 additions and 10 deletions
@ -0,0 +1,105 @@ |
||||
package main |
||||
|
||||
import ( |
||||
"io" |
||||
"net/http" |
||||
"net/http/httptest" |
||||
"os" |
||||
"path/filepath" |
||||
"testing" |
||||
|
||||
"github.com/stretchr/testify/assert" |
||||
"github.com/stretchr/testify/require" |
||||
) |
||||
|
||||
func initTestHttpLogs(logFile string) (http.Handler, error) { |
||||
|
||||
app := &goBlog{ |
||||
cfg: &config{ |
||||
Server: &configServer{ |
||||
Logging: true, |
||||
LogFile: logFile, |
||||
}, |
||||
}, |
||||
} |
||||
|
||||
err := app.initHTTPLog() |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
|
||||
return app.logMiddleware(testHttpHandler()), nil |
||||
|
||||
} |
||||
|
||||
func testHttpHandler() http.Handler { |
||||
return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { |
||||
_, _ = rw.Write([]byte("Test")) |
||||
}) |
||||
} |
||||
|
||||
func Test_httpLogs(t *testing.T) { |
||||
|
||||
// Init
|
||||
|
||||
logFile := filepath.Join(t.TempDir(), "access.log") |
||||
handler, err := initTestHttpLogs(logFile) |
||||
|
||||
require.NoError(t, err) |
||||
|
||||
// Do fake request
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/testpath", nil) |
||||
rec := httptest.NewRecorder() |
||||
|
||||
handler.ServeHTTP(rec, req) |
||||
|
||||
// Check response
|
||||
|
||||
res := rec.Result() |
||||
resBody, _ := io.ReadAll(res.Body) |
||||
resBodyStr := string(resBody) |
||||
|
||||
assert.Equal(t, http.StatusOK, res.StatusCode) |
||||
assert.Contains(t, resBodyStr, "Test") |
||||
|
||||
// Check log
|
||||
|
||||
logBytes, err := os.ReadFile(logFile) |
||||
require.NoError(t, err) |
||||
|
||||
logString := string(logBytes) |
||||
assert.Contains(t, logString, "GET /testpath") |
||||
|
||||
} |
||||
|
||||
func Benchmark_httpLogs(b *testing.B) { |
||||
|
||||
// Init
|
||||
|
||||
logFile := filepath.Join(b.TempDir(), "access.log") |
||||
logHandler, err := initTestHttpLogs(logFile) |
||||
require.NoError(b, err) |
||||
|
||||
noLogHandler := testHttpHandler() |
||||
|
||||
// Run benchmarks
|
||||
|
||||
b.Run("With logging", func(b *testing.B) { |
||||
b.RunParallel(func(p *testing.PB) { |
||||
for p.Next() { |
||||
logHandler.ServeHTTP(httptest.NewRecorder(), httptest.NewRequest(http.MethodGet, "/testpath", nil)) |
||||
} |
||||
}) |
||||
|
||||
}) |
||||
|
||||
b.Run("Without logging", func(b *testing.B) { |
||||
b.RunParallel(func(p *testing.PB) { |
||||
for p.Next() { |
||||
noLogHandler.ServeHTTP(httptest.NewRecorder(), httptest.NewRequest(http.MethodGet, "/testpath", nil)) |
||||
} |
||||
}) |
||||
}) |
||||
|
||||
} |
@ -0,0 +1,106 @@ |
||||
package main |
||||
|
||||
import ( |
||||
"io" |
||||
"net/http" |
||||
"net/http/httptest" |
||||
"path/filepath" |
||||
"testing" |
||||
|
||||
"github.com/go-chi/chi/v5/middleware" |
||||
"github.com/justinas/alice" |
||||
"github.com/stretchr/testify/assert" |
||||
) |
||||
|
||||
func Test_privateMode(t *testing.T) { |
||||
|
||||
// Init
|
||||
|
||||
app := &goBlog{ |
||||
cfg: &config{ |
||||
Db: &configDb{ |
||||
File: filepath.Join(t.TempDir(), "db.db"), |
||||
}, |
||||
Server: &configServer{}, |
||||
PrivateMode: &configPrivateMode{ |
||||
Enabled: true, |
||||
}, |
||||
User: &configUser{ |
||||
Name: "Test", |
||||
Nick: "test", |
||||
Password: "testpw", |
||||
AppPasswords: []*configAppPassword{ |
||||
{ |
||||
Username: "testapp", |
||||
Password: "pw", |
||||
}, |
||||
}, |
||||
}, |
||||
DefaultBlog: "en", |
||||
Blogs: map[string]*configBlog{ |
||||
"en": { |
||||
Lang: "en", |
||||
}, |
||||
}, |
||||
}, |
||||
} |
||||
|
||||
_ = app.initDatabase(false) |
||||
app.initComponents(false) |
||||
|
||||
handler := alice.New(middleware.WithValue(blogKey, "en"), app.privateModeHandler).ThenFunc(func(rw http.ResponseWriter, r *http.Request) { |
||||
_, _ = rw.Write([]byte("Awesome")) |
||||
}) |
||||
|
||||
// Test check
|
||||
|
||||
assert.True(t, app.isPrivate()) |
||||
|
||||
// Test successful request
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/test", nil) |
||||
req.SetBasicAuth("testapp", "pw") |
||||
rec := httptest.NewRecorder() |
||||
|
||||
handler.ServeHTTP(rec, req) |
||||
|
||||
res := rec.Result() |
||||
resBody, _ := io.ReadAll(res.Body) |
||||
resBodyStr := string(resBody) |
||||
|
||||
assert.Equal(t, http.StatusOK, res.StatusCode) |
||||
assert.Equal(t, "Awesome", resBodyStr) |
||||
|
||||
// Test unauthenticated request
|
||||
|
||||
req = httptest.NewRequest(http.MethodGet, "/test", nil) |
||||
rec = httptest.NewRecorder() |
||||
|
||||
handler.ServeHTTP(rec, req) |
||||
|
||||
res = rec.Result() |
||||
resBody, _ = io.ReadAll(res.Body) |
||||
resBodyStr = string(resBody) |
||||
|
||||
assert.Equal(t, http.StatusOK, res.StatusCode) // Not 401, to be compatible with some apps
|
||||
assert.NotEqual(t, "Awesome", resBodyStr) |
||||
assert.NotContains(t, resBodyStr, "Awesome") |
||||
assert.Contains(t, resBodyStr, "Login") |
||||
|
||||
// Disable private mode
|
||||
|
||||
app.cfg.PrivateMode.Enabled = false |
||||
|
||||
req = httptest.NewRequest(http.MethodGet, "/test", nil) |
||||
rec = httptest.NewRecorder() |
||||
|
||||
handler.ServeHTTP(rec, req) |
||||
|
||||
res = rec.Result() |
||||
resBody, _ = io.ReadAll(res.Body) |
||||
resBodyStr = string(resBody) |
||||
|
||||
assert.Equal(t, http.StatusOK, res.StatusCode) |
||||
assert.Equal(t, "Awesome", resBodyStr) |
||||
|
||||
} |
Loading…
Reference in new issue