Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion tavern/app.go
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,7 @@ func NewServer(ctx context.Context, options ...func(*Config)) (*Server, error) {
)},
"/graphql": tavernhttp.Endpoint{Handler: newGraphQLHandler(client, git)},
"/c2.C2/": tavernhttp.Endpoint{Handler: newGRPCHandler(client)},
"/cdn/": tavernhttp.Endpoint{Handler: cdn.NewDownloadHandler(client)},
"/cdn/": tavernhttp.Endpoint{Handler: cdn.NewDownloadHandler(client, "/cdn/")},
"/cdn/upload": tavernhttp.Endpoint{Handler: cdn.NewUploadHandler(client)},
"/": tavernhttp.Endpoint{
Handler: www.NewHandler(httpLogger),
Expand Down
6 changes: 3 additions & 3 deletions tavern/internal/cdn/download.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ package cdn
import (
"bytes"
"net/http"
"path/filepath"
"strings"

"realm.pub/tavern/internal/ent"
"realm.pub/tavern/internal/ent/file"
Expand All @@ -18,12 +18,12 @@ const (
)

// NewDownloadHandler returns an HTTP handler responsible for downloading a file from the CDN.
func NewDownloadHandler(graph *ent.Client) http.Handler {
func NewDownloadHandler(graph *ent.Client, prefix string) http.Handler {
return errors.WrapHandler(func(w http.ResponseWriter, req *http.Request) error {
ctx := req.Context()

// Get the File name from the request URI
fileName := filepath.Base(req.URL.Path)
fileName := strings.TrimPrefix(req.URL.Path, prefix)
if fileName == "" || fileName == "." || fileName == "/" {
return ErrInvalidFileName
}
Expand Down
2 changes: 1 addition & 1 deletion tavern/internal/cdn/download_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ func TestDownload(t *testing.T) {
func newDownloadTest(graph *ent.Client, req *http.Request, checks ...func(t *testing.T, fileContent []byte, err *errors.HTTP)) func(*testing.T) {
return func(t *testing.T) {
// Initialize Download Handler
handler := cdn.NewDownloadHandler(graph)
handler := cdn.NewDownloadHandler(graph, "/download/")

// Send request and record response
w := httptest.NewRecorder()
Expand Down