Skip to content

Commit

Permalink
chore(dataacess): Preventing against NoSQL injection. (#97)
Browse files Browse the repository at this point in the history
* Trying to fix nosql injection

* Correcting the nosql injection pt2

* Additional validation

* Additional validation

* Additional validation

* Additional validation

* updating dependencies
  • Loading branch information
Jacobbrewer1 authored Mar 5, 2024
1 parent 7f9d9eb commit 099b468
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 23 deletions.
18 changes: 16 additions & 2 deletions pkg/dataaccess/db_mongo.go
Original file line number Diff line number Diff line change
Expand Up @@ -198,14 +198,21 @@ func (m *mongodbImpl) GetHistory(ctx context.Context, environment ...summary.Env
}

func (m *mongodbImpl) GetReport(ctx context.Context, id string) (*entities.PuppetReport, error) {
if id == "" {
return nil, errors.New("id cannot be empty")
}

collection := m.client.Database(mongoDatabase).Collection("reports")

// Start the prometheus metrics.
t := prometheus.NewTimer(DatabaseLatency.WithLabelValues("get_report"))
defer t.ObserveDuration()

var report entities.PuppetReport
err := collection.FindOne(ctx, bson.M{"id": id}).Decode(&report)
err := collection.FindOne(ctx, bson.M{"id": bson.M{
"$eq": id,
"$ne": "", // This is to ensure that the id is not empty. AKA NOSQL injection.
}}).Decode(&report)
if errors.Is(err, mongo.ErrNoDocuments) {
return nil, ErrNotFound
} else if err != nil {
Expand All @@ -216,12 +223,19 @@ func (m *mongodbImpl) GetReport(ctx context.Context, id string) (*entities.Puppe
}

func (m *mongodbImpl) GetReports(ctx context.Context, fqdn string) ([]*entities.PuppetReportSummary, error) {
if fqdn == "" {
return nil, errors.New("fqdn cannot be empty")
}

collection := m.client.Database(mongoDatabase).Collection("reports")

// Start the prometheus metrics.
t := prometheus.NewTimer(DatabaseLatency.WithLabelValues("get_reports"))

cursor, err := collection.Find(ctx, bson.M{"fqdn": fqdn})
cursor, err := collection.Find(ctx, bson.M{"fqdn": bson.M{
"$eq": fqdn,
"$ne": "", // This is to ensure that the fqdn is not empty. AKA NOSQL injection.
}})
if err != nil {
return nil, fmt.Errorf("error getting reports: %w", err)
}
Expand Down
19 changes: 8 additions & 11 deletions pkg/services/web/nodes.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,21 +16,18 @@ import (
"github.com/Jacobbrewer1/puppet-summary/pkg/logging"
"github.com/Jacobbrewer1/puppet-summary/pkg/request"
"github.com/gorilla/mux"
"github.com/oapi-codegen/runtime"
)

func (s service) nodeFqdnHandler(w http.ResponseWriter, r *http.Request) {
nodeFqdn, ok := mux.Vars(r)["node_fqdn"]
if !ok {
// Respond with 400 bad request.
w.WriteHeader(http.StatusBadRequest)
if err := json.NewEncoder(w).Encode(request.NewMessage("No node fqdn provided")); err != nil {
slog.Warn("Error encoding response", slog.String(logging.KeyError, err.Error()))
}
return
} else if nodeFqdn == "" {
// Respond with 400 bad request.
// ------------- Path parameter "node_fqdn" -------------
var nodeFqdn string

err := runtime.BindStyledParameterWithOptions("simple", "node_fqdn", mux.Vars(r)["node_fqdn"], &nodeFqdn, runtime.BindStyledParameterOptions{Explode: false, Required: true})
if err != nil {
slog.Error("Error binding path parameter", slog.String(logging.KeyError, err.Error()))
w.WriteHeader(http.StatusBadRequest)
if err := json.NewEncoder(w).Encode(request.NewMessage("No node fqdn provided")); err != nil {
if err := json.NewEncoder(w).Encode(request.NewMessage("Error binding path parameter")); err != nil {
slog.Warn("Error encoding response", slog.String(logging.KeyError, err.Error()))
}
return
Expand Down
18 changes: 8 additions & 10 deletions pkg/services/web/reports.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,26 +19,24 @@ import (
"github.com/Jacobbrewer1/puppet-summary/pkg/request"
"github.com/Jacobbrewer1/puppet-summary/pkg/services/parser"
"github.com/gorilla/mux"
"github.com/oapi-codegen/runtime"
)

func (s service) reportIDHandler(w http.ResponseWriter, r *http.Request) {
id, ok := mux.Vars(r)["report_id"]
if !ok {
w.WriteHeader(http.StatusBadRequest)
if err := json.NewEncoder(w).Encode(request.NewMessage("No report ID provided")); err != nil {
slog.Warn("Error encoding response", slog.String(logging.KeyError, err.Error()))
}
return
} else if id == "" {
// ------------- Path parameter "reportId" -------------
var reportId string
err := runtime.BindStyledParameterWithOptions("simple", "report_id", mux.Vars(r)["report_id"], &reportId, runtime.BindStyledParameterOptions{Explode: false, Required: true})
if err != nil {
slog.Error("Error binding path parameter", slog.String(logging.KeyError, err.Error()))
w.WriteHeader(http.StatusBadRequest)
if err := json.NewEncoder(w).Encode(request.NewMessage("Invalid report ID provided")); err != nil {
if err := json.NewEncoder(w).Encode(request.NewMessage("Error binding path parameter")); err != nil {
slog.Warn("Error encoding response", slog.String(logging.KeyError, err.Error()))
}
return
}

// Get the report from the database.
rep, err := s.db.GetReport(r.Context(), id)
rep, err := s.db.GetReport(r.Context(), reportId)
if errors.Is(err, dataaccess.ErrNotFound) {
w.WriteHeader(http.StatusNotFound)
if err := json.NewEncoder(w).Encode(request.NewMessage("Report not found")); err != nil {
Expand Down

0 comments on commit 099b468

Please sign in to comment.