🎣 Open-Source Phishing Toolkit
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

211 lines
6.2 KiB

  1. package middleware
  2. import (
  3. "fmt"
  4. "net/http"
  5. "net/http/httptest"
  6. "testing"
  7. "github.com/gophish/gophish/config"
  8. ctx "github.com/gophish/gophish/context"
  9. "github.com/gophish/gophish/models"
  10. )
  // successHandler is the innermost handler wrapped by the middleware under
  // test. It only writes the body "success" and never calls WriteHeader, so a
  // request that reaches it ends up with net/http's implicit 200 OK status.
  var successHandler http.HandlerFunc = func(rw http.ResponseWriter, _ *http.Request) {
  	body := []byte("success")
  	rw.Write(body)
  }
  // testContext holds the fixtures that setupTest hands back to each test.
  type testContext struct {
  	// apiKey is the API key of the user fetched by setupTest (user ID 1).
  	apiKey string
  }
  17. func setupTest(t *testing.T) *testContext {
  18. conf := &config.Config{
  19. DBName: "sqlite3",
  20. DBPath: ":memory:",
  21. MigrationsPath: "../db/db_sqlite3/migrations/",
  22. }
  23. err := models.Setup(conf)
  24. if err != nil {
  25. t.Fatalf("Failed creating database: %v", err)
  26. }
  27. // Get the API key to use for these tests
  28. u, err := models.GetUser(1)
  29. if err != nil {
  30. t.Fatalf("error getting user: %v", err)
  31. }
  32. ctx := &testContext{}
  33. ctx.apiKey = u.ApiKey
  34. return ctx
  35. }
  // MiddlewarePermissionTest is a table of expectations for a single role.
  // Each key is an HTTP method, and its value is the HTTP status code the
  // middleware is expected to produce for a request using that method.
  type MiddlewarePermissionTest map[string]int
  39. // TestEnforceViewOnly ensures that only users with the ModifyObjects
  40. // permission have the ability to send non-GET requests.
  41. func TestEnforceViewOnly(t *testing.T) {
  42. setupTest(t)
  43. permissionTests := map[string]MiddlewarePermissionTest{
  44. models.RoleAdmin: MiddlewarePermissionTest{
  45. http.MethodGet: http.StatusOK,
  46. http.MethodHead: http.StatusOK,
  47. http.MethodOptions: http.StatusOK,
  48. http.MethodPost: http.StatusOK,
  49. http.MethodPut: http.StatusOK,
  50. http.MethodDelete: http.StatusOK,
  51. },
  52. models.RoleUser: MiddlewarePermissionTest{
  53. http.MethodGet: http.StatusOK,
  54. http.MethodHead: http.StatusOK,
  55. http.MethodOptions: http.StatusOK,
  56. http.MethodPost: http.StatusOK,
  57. http.MethodPut: http.StatusOK,
  58. http.MethodDelete: http.StatusOK,
  59. },
  60. }
  61. for r, checks := range permissionTests {
  62. role, err := models.GetRoleBySlug(r)
  63. if err != nil {
  64. t.Fatalf("error getting role by slug: %v", err)
  65. }
  66. for method, expected := range checks {
  67. req := httptest.NewRequest(method, "/", nil)
  68. response := httptest.NewRecorder()
  69. req = ctx.Set(req, "user", models.User{
  70. Role: role,
  71. RoleID: role.ID,
  72. })
  73. EnforceViewOnly(successHandler).ServeHTTP(response, req)
  74. got := response.Code
  75. if got != expected {
  76. t.Fatalf("incorrect status code received. expected %d got %d", expected, got)
  77. }
  78. }
  79. }
  80. }
  81. func TestRequirePermission(t *testing.T) {
  82. setupTest(t)
  83. middleware := RequirePermission(models.PermissionModifySystem)
  84. handler := middleware(successHandler)
  85. permissionTests := map[string]int{
  86. models.RoleUser: http.StatusForbidden,
  87. models.RoleAdmin: http.StatusOK,
  88. }
  89. for role, expected := range permissionTests {
  90. req := httptest.NewRequest(http.MethodGet, "/", nil)
  91. response := httptest.NewRecorder()
  92. // Test that with the requested permission, the request succeeds
  93. role, err := models.GetRoleBySlug(role)
  94. if err != nil {
  95. t.Fatalf("error getting role by slug: %v", err)
  96. }
  97. req = ctx.Set(req, "user", models.User{
  98. Role: role,
  99. RoleID: role.ID,
  100. })
  101. handler.ServeHTTP(response, req)
  102. got := response.Code
  103. if got != expected {
  104. t.Fatalf("incorrect status code received. expected %d got %d", expected, got)
  105. }
  106. }
  107. }
  108. func TestRequireAPIKey(t *testing.T) {
  109. setupTest(t)
  110. req := httptest.NewRequest(http.MethodGet, "/", nil)
  111. req.Header.Set("Content-Type", "application/json")
  112. response := httptest.NewRecorder()
  113. // Test that making a request without an API key is denied
  114. RequireAPIKey(successHandler).ServeHTTP(response, req)
  115. expected := http.StatusUnauthorized
  116. got := response.Code
  117. if got != expected {
  118. t.Fatalf("incorrect status code received. expected %d got %d", expected, got)
  119. }
  120. }
  121. func TestCORSHeaders(t *testing.T) {
  122. setupTest(t)
  123. req := httptest.NewRequest(http.MethodOptions, "/", nil)
  124. response := httptest.NewRecorder()
  125. RequireAPIKey(successHandler).ServeHTTP(response, req)
  126. expected := "POST, GET, OPTIONS, PUT, DELETE"
  127. got := response.Result().Header.Get("Access-Control-Allow-Methods")
  128. if got != expected {
  129. t.Fatalf("incorrect cors options received. expected %s got %s", expected, got)
  130. }
  131. }
  132. func TestInvalidAPIKey(t *testing.T) {
  133. setupTest(t)
  134. req := httptest.NewRequest(http.MethodGet, "/", nil)
  135. query := req.URL.Query()
  136. query.Set("api_key", "bogus-api-key")
  137. req.URL.RawQuery = query.Encode()
  138. req.Header.Set("Content-Type", "application/json")
  139. response := httptest.NewRecorder()
  140. RequireAPIKey(successHandler).ServeHTTP(response, req)
  141. expected := http.StatusUnauthorized
  142. got := response.Code
  143. if got != expected {
  144. t.Fatalf("incorrect status code received. expected %d got %d", expected, got)
  145. }
  146. }
  147. func TestBearerToken(t *testing.T) {
  148. testCtx := setupTest(t)
  149. req := httptest.NewRequest(http.MethodGet, "/", nil)
  150. req.Header.Add("Authorization", fmt.Sprintf("Bearer %s", testCtx.apiKey))
  151. req.Header.Set("Content-Type", "application/json")
  152. response := httptest.NewRecorder()
  153. RequireAPIKey(successHandler).ServeHTTP(response, req)
  154. expected := http.StatusOK
  155. got := response.Code
  156. if got != expected {
  157. t.Fatalf("incorrect status code received. expected %d got %d", expected, got)
  158. }
  159. }
  160. func TestPasswordResetRequired(t *testing.T) {
  161. req := httptest.NewRequest(http.MethodGet, "/", nil)
  162. req = ctx.Set(req, "user", models.User{
  163. PasswordChangeRequired: true,
  164. })
  165. response := httptest.NewRecorder()
  166. RequireLogin(successHandler).ServeHTTP(response, req)
  167. gotStatus := response.Code
  168. expectedStatus := http.StatusTemporaryRedirect
  169. if gotStatus != expectedStatus {
  170. t.Fatalf("incorrect status code received. expected %d got %d", expectedStatus, gotStatus)
  171. }
  172. expectedLocation := "/reset_password?next=%2F"
  173. gotLocation := response.Header().Get("Location")
  174. if gotLocation != expectedLocation {
  175. t.Fatalf("incorrect location header received. expected %s got %s", expectedLocation, gotLocation)
  176. }
  177. }
  178. func TestApplySecurityHeaders(t *testing.T) {
  179. expected := map[string]string{
  180. "Content-Security-Policy": "frame-ancestors 'none';",
  181. "X-Frame-Options": "DENY",
  182. }
  183. req := httptest.NewRequest(http.MethodGet, "/", nil)
  184. response := httptest.NewRecorder()
  185. ApplySecurityHeaders(successHandler).ServeHTTP(response, req)
  186. for header, value := range expected {
  187. got := response.Header().Get(header)
  188. if got != value {
  189. t.Fatalf("incorrect security header received for %s: expected %s got %s", header, value, got)
  190. }
  191. }
  192. }