Skip to content

Commit

Permalink
add helper for ordered parameters
Browse files Browse the repository at this point in the history
  • Loading branch information
tarunKoyalwar committed Jun 28, 2023
1 parent eaf3001 commit 668c800
Show file tree
Hide file tree
Showing 4 changed files with 344 additions and 0 deletions.
91 changes: 91 additions & 0 deletions maps/ordered_map.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
package mapsutil

import (
sliceutil "github.com/projectdiscovery/utils/slice"
"golang.org/x/exp/maps"
)

// OrderedMap is a map that preserves the order of elements
type OrderedMap[k comparable, v any] struct {
keys []k
m map[k]v
}

// Set sets a value in the OrderedMap (if the key already exists, it will be overwritten)
func (o *OrderedMap[k, v]) Set(key k, value v) {
o.m[key] = value
o.keys = append(o.keys, key)
}

// Get gets a value from the OrderedMap
func (o *OrderedMap[k, v]) Get(key k) (v, bool) {
value, ok := o.m[key]
return value, ok
}

// Iterate iterates over the OrderedMap
func (o *OrderedMap[k, v]) Iterate(f func(key k, value v) bool) {
for _, key := range o.keys {
if !f(key, o.m[key]) {
break
}
}
}

// GetKeys returns the keys of the OrderedMap
func (o *OrderedMap[k, v]) GetKeys() []k {
return o.keys
}

// Has checks if the OrderedMap has the provided key
func (o *OrderedMap[k, v]) Has(key k) bool {
_, ok := o.m[key]
return ok
}

// IsEmpty checks if the OrderedMap is empty
func (o *OrderedMap[k, v]) IsEmpty() bool {
return len(o.keys) == 0
}

// Clone returns clone of OrderedMap
func (o *OrderedMap[k, v]) Clone() *OrderedMap[k, v] {
return &OrderedMap[k, v]{
keys: sliceutil.Clone(o.keys),
m: maps.Clone(o.m),
}
}

// GetByIndex gets a value from the OrderedMap by index
func (o *OrderedMap[k, v]) GetByIndex(index int) (v, bool) {
var t v
if index < 0 || index >= len(o.keys) {
return t, false
}
key := o.keys[index]
return o.m[key], true
}

// Delete deletes a value from the OrderedMap
func (o *OrderedMap[k, v]) Delete(key k) {
delete(o.m, key)
for i, k := range o.keys {
if k == key {
o.keys = append(o.keys[:i], o.keys[i+1:]...)
break
}
}
}

// Len returns the length of the OrderedMap
func (o *OrderedMap[k, v]) Len() int {
return len(o.keys)
}

// NewOrderedMap creates a new OrderedMap
func NewOrderedMap[k comparable, v any]() *OrderedMap[k, v] {
return &OrderedMap[k, v]{
keys: []k{},
m: map[k]v{},
}
}
72 changes: 72 additions & 0 deletions maps/ordered_map_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
package mapsutil

import (
"fmt"
"strconv"
"testing"
)

func TestOrderedMapBasic(t *testing.T) {
m := NewOrderedMap[string, string]()
m.Set("test", "test")
if m.IsEmpty() {
t.Fatal("ordered map is empty")
}
if !m.Has("test") {
t.Fatal("ordered map doesn't have test key")
}
if m.Has("test2") {
t.Fatal("ordered map has test2 key")
}
if val, ok := m.Get("test"); !ok || val != "test" {
t.Fatal("ordered map get test key doesn't return test value")
}
if m.GetKeys()[0] != "test" {
t.Fatal("ordered map get keys doesn't return test key")
}
if val, ok := m.GetByIndex(0); !ok || val != "test" {
t.Fatal("ordered map get by index doesn't return test key")
}
m.Delete("test")
if !m.IsEmpty() {
t.Fatal("ordered map is not empty after delete")
}
}

func TestOrderedMap(t *testing.T) {
m := NewOrderedMap[string, string]()
for i := 0; i < 110; i++ {
m.Set(strconv.Itoa(i), fmt.Sprintf("value-%d", i))
}

// iterate and validate order
i := 0
m.Iterate(func(key string, value string) bool {
if key != strconv.Itoa(i) {
t.Fatal("ordered map iterate order is not correct")
}
i++
return true
})

// validate get by index
for i := 0; i < 100; i++ {
if val, ok := m.GetByIndex(i); !ok || val != fmt.Sprintf("value-%d", i) {
t.Fatal("ordered map get by index doesn't return correct value")
}
}

// random delete and validate order
deleteElements := []int{0, 10, 20, 30, 40, 50, 60, 70, 80, 90}
for _, i := range deleteElements {
m.Delete(strconv.Itoa(i))
}

// validate elements after delete
for k, i := range deleteElements {
if val, ok := m.GetByIndex(i); !ok || val != fmt.Sprintf("value-%d", i+k+1) {
t.Logf("order mismatch after delete got: index: %d, value: %s, exists: %v", i, val, ok)
}
}

}
126 changes: 126 additions & 0 deletions url/orderedparams.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
package urlutil

import (
"bytes"
"strings"

mapsutil "github.com/projectdiscovery/utils/maps"
)

// Only difference between OrderedParams and Params is that
// OrderedParams preserves order of parameters everythign else is same

// OrderedParams is a map that preserves the order of elements
type OrderedParams struct {
om *mapsutil.OrderedMap[string, []string]
}

// NewOrderedParams creates a new ordered params
func NewOrderedParams() *OrderedParams {
return &OrderedParams{
om: mapsutil.NewOrderedMap[string, []string](),
}
}

// Add Parameters to store
func (o *OrderedParams) Add(key string, value ...string) {
if arr, ok := o.om.Get(key); ok && len(arr) > 0 {
if len(value) != 0 {
o.om.Set(key, append(arr, value...))
}
} else {
o.om.Set(key, value)
}
}

// Set sets the key to value and replaces if already exists
func (o *OrderedParams) Set(key string, value string) {
o.om.Set(key, []string{value})
}

// Get returns first value of given key
func (o *OrderedParams) Get(key string) string {
val, ok := o.om.Get(key)
if !ok || len(val) == 0 {
return ""
}
return val[0]
}

// Has returns if given key exists
func (o *OrderedParams) Has(key string) bool {
return o.om.Has(key)
}

// Del deletes values associated with key
func (o *OrderedParams) Del(key string) {
o.om.Delete(key)
}

// Merges given paramset into existing one with base as priority
func (o *OrderedParams) Merge(raw string) {

}

// Encode returns encoded parameters by preserving order
func (o *OrderedParams) Encode() string {
if o.om.IsEmpty() {
return ""
}
var buf strings.Builder
for _, k := range o.om.GetKeys() {
vs, _ := o.om.Get(k)
keyEscaped := ParamEncode(k)
for _, v := range vs {
if buf.Len() > 0 {
buf.WriteByte('&')
}
buf.WriteString(keyEscaped)
value := ParamEncode(v)
// donot specify = if parameter has no value (reference: nuclei-templates)
if value != "" {
buf.WriteRune('=')
buf.WriteString(value)
}
}
}
return buf.String()
}

// Decode is opposite of Encode() where ("bar=baz&foo=quux") is parsed
// Parameters are loosely parsed to allow any scenario
func (o *OrderedParams) Decode(raw string) {
if o.om == nil {
o.om = mapsutil.NewOrderedMap[string, []string]()
}
arr := []string{}
var tbuff bytes.Buffer
for _, v := range raw {
switch v {
case '&':
arr = append(arr, tbuff.String())
tbuff.Reset()
case ';':
if AllowLegacySeperator {
arr = append(arr, tbuff.String())
tbuff.Reset()
continue
}
tbuff.WriteRune(v)
default:
tbuff.WriteRune(v)
}
}
if tbuff.Len() > 0 {
arr = append(arr, tbuff.String())
}

for _, pair := range arr {
d := strings.SplitN(pair, "=", 2)
if len(d) == 2 {
o.Add(d[0], d[1])
} else if len(d) == 1 {
o.Add(d[0], "")
}
}
}
55 changes: 55 additions & 0 deletions url/orderedparams_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
package urlutil

import (
"net/http"
"net/http/httptest"
"net/url"
"testing"

"github.com/stretchr/testify/require"
)

func TestOrderedParam(t *testing.T) {
p := NewOrderedParams()
p.Add("sqli", "1+AND+(SELECT+*+FROM+(SELECT(SLEEP(12)))nQIP)")
p.Add("xss", "<script>alert('XSS')</script>")
p.Add("xssiwthspace", "<svg id=alert(1) onload=eval(id)>")
p.Add("jsprotocol", "javascript://alert(1)")
// Note keys are sorted
expected := "sqli=1+AND+(SELECT+*+FROM+(SELECT(SLEEP(12)))nQIP)&xss=<script>alert('XSS')</script>&xssiwthspace=<svg+id=alert(1)+onload=eval(id)>&jsprotocol=javascript://alert(1)"
require.Equalf(t, expected, p.Encode(), "failed to encode parameters expected %v but got %v", expected, p.Encode())
}

// TestOrderedParamIntegration preserves order of parameters
// while sending request to server (ref:https://github.com/projectdiscovery/nuclei/issues/3801)
func TestOrderedParamIntegration(t *testing.T) {
expected := "/?xss=<script>alert('XSS')</script>&sqli=1+AND+(SELECT+*+FROM+(SELECT(SLEEP(12)))nQIP)&jsprotocol=javascript://alert(1)&xssiwthspace=<svg+id=alert(1)+onload=eval(id)>"

srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
require.Equalf(t, expected, r.RequestURI, "expected %v but got %v", expected, r.RequestURI)
w.WriteHeader(http.StatusOK)
}))
defer srv.Close()

p := NewOrderedParams()
p.Add("xss", "<script>alert('XSS')</script>")
p.Add("sqli", "1+AND+(SELECT+*+FROM+(SELECT(SLEEP(12)))nQIP)")
p.Add("jsprotocol", "javascript://alert(1)")
p.Add("xssiwthspace", "<svg id=alert(1) onload=eval(id)>")

url, err := url.Parse(srv.URL)
require.Nil(t, err)
url.RawQuery = p.Encode()
_, err = http.Get(url.String())
require.Nil(t, err)
}

func TestGetOrderedParams(t *testing.T) {
values := url.Values{}
values.Add("sqli", "1+AND+(SELECT+*+FROM+(SELECT(SLEEP(12)))nQIP)")
values.Add("xss", "<script>alert('XSS')</script>")
p := GetParams(values)
require.NotNilf(t, p, "expected params but got nil")
require.Equalf(t, p.Get("sqli"), values.Get("sqli"), "malformed or missing value for param sqli expected %v but got %v", values.Get("sqli"), p.Get("sqli"))
require.Equalf(t, p.Get("xss"), values.Get("xss"), "malformed or missing value for param xss expected %v but got %v", values.Get("xss"), p.Get("xss"))
}

0 comments on commit 668c800

Please sign in to comment.