diff --git a/cmd/xgql/main.go b/cmd/xgql/main.go index 3ba686f..4c9a9af 100644 --- a/cmd/xgql/main.go +++ b/cmd/xgql/main.go @@ -38,6 +38,7 @@ import ( google "github.com/GoogleCloudPlatform/opentelemetry-operations-go/exporter/trace" "github.com/go-chi/chi/v5" "github.com/go-chi/chi/v5/middleware" + "github.com/gorilla/websocket" "github.com/pkg/errors" "github.com/prometheus/client_golang/prometheus/promhttp" "go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp" @@ -211,7 +212,12 @@ func main() { srv := handler.New(generated.NewExecutableSchema(generated.Config{Resolvers: resolvers.New(ca)})) srv.AddTransport(transport.Websocket{ + Upgrader: websocket.Upgrader{ + // Enable per message compression. + EnableCompression: true, + }, PingPongInterval: 10 * time.Second, + InitFunc: auth.WebsocketInit, }) srv.AddTransport(transport.Options{}) srv.AddTransport(transport.GET{}) @@ -229,7 +235,6 @@ func main() { srv.Use(opentelemetry.Tracer{}) srv.Use(apollotracing.Tracer{}) srv.Use(live_query.LiveQuery{}) - srv.AroundOperations(auth.OperationMiddleware) rt.Handle("/query", otelhttp.NewHandler(srv, "/query")) rt.Handle("/metrics", promhttp.Handler()) diff --git a/go.mod b/go.mod index 437a0f4..4e39e11 100644 --- a/go.mod +++ b/go.mod @@ -66,9 +66,10 @@ require ( github.com/google/uuid v1.3.0 // indirect github.com/googleapis/enterprise-certificate-proxy v0.2.3 // indirect github.com/googleapis/gax-go/v2 v2.11.0 // indirect - github.com/gorilla/websocket v1.5.0 // indirect + github.com/gorilla/websocket v1.5.0 github.com/hashicorp/golang-lru/v2 v2.0.3 // indirect github.com/imdario/mergo v0.3.16 // indirect + github.com/josephburnett/jd v1.7.1 github.com/josharian/intern v1.0.0 // indirect github.com/json-iterator/go v1.1.12 // indirect github.com/mailru/easyjson v0.7.7 // indirect diff --git a/go.sum b/go.sum index 485a763..52baa55 100644 --- a/go.sum +++ b/go.sum @@ -479,6 +479,8 @@ github.com/jmespath/go-jmespath v0.0.0-20180206201540-c2b33e8439af/go.mod h1:Nht github.com/jmespath/go-jmespath v0.3.0/go.mod h1:9QtRXoHjLGCJ5IBSaohpXITPlowMeeYCZ7fLUTSywik= github.com/joefitzgerald/rainbow-reporter v0.1.0/go.mod h1:481CNgqmVHQZzdIbN52CupLJyoVwB10FQ/IQlF1pdL8= github.com/jonboulle/clockwork v0.1.0/go.mod h1:Ii8DK3G1RaLaWxj9trq07+26W01tbo22gdxWY5EU2bo= +github.com/josephburnett/jd v1.7.1 h1:oXBPMS+SNnILTMGj1fWLK9pexpeJUXtbVFfRku/PjBU= +github.com/josephburnett/jd v1.7.1/go.mod h1:R8ZnZnLt2D4rhW4NvBc/USTo6mzyNT6fYNIIWOJA9GY= github.com/josharian/intern v1.0.0 h1:vlS4z54oSdjm0bgjRigI+G1HpF+tI+9rE5LLzOg8HmY= github.com/josharian/intern v1.0.0/go.mod h1:5DoeVV0s6jJacbCEi61lwdGj/aVlrQvzHFFd8Hwg//Y= github.com/json-iterator/go v0.0.0-20180612202835-f2b4162afba3/go.mod h1:+SdeFBvtyEkXs7REEP0seUULqWtbJapLOCVDaaPEHmU= diff --git a/internal/auth/auth.go b/internal/auth/auth.go index 6e8c72f..d4781d8 100644 --- a/internal/auth/auth.go +++ b/internal/auth/auth.go @@ -21,7 +21,6 @@ import ( "net/http" "strings" - "github.com/99designs/gqlgen/graphql" "github.com/99designs/gqlgen/graphql/handler/transport" "k8s.io/client-go/rest" ) @@ -143,27 +142,24 @@ func Middleware(next http.Handler) http.Handler { }) } -func OperationMiddleware(ctx context.Context, next graphql.OperationHandler) graphql.ResponseHandler { - if initPayload := transport.GetInitPayload(ctx); initPayload != nil { - r := &http.Request{ - Header: make(http.Header), - } - for k := range initPayload { - s := initPayload.GetString(k) - if s == "" { - continue - } - r.Header.Add(k, s) +func WebsocketInit(ctx context.Context, initPayload transport.InitPayload) (context.Context, error) { + r := &http.Request{ + Header: make(http.Header), + } + for k := range initPayload { + s := initPayload.GetString(k) + if s == "" { + continue } - bu, bp, _ := r.BasicAuth() - ctx = context.WithValue(ctx, key, Credentials{ - BasicUsername: bu, - BasicPassword: bp, - BearerToken: ExtractBearerToken(r), - Impersonate: ExtractImpersonation(r), - }) + r.Header.Add(k, s) } - return next(ctx) + bu, bp, _ := r.BasicAuth() + return context.WithValue(ctx, key, Credentials{ + BasicUsername: bu, + BasicPassword: bp, + BearerToken: ExtractBearerToken(r), + Impersonate: ExtractImpersonation(r), + }), nil } // FromContext extracts credentials from the supplied context. diff --git a/internal/clients/live_query.go b/internal/clients/live_query.go index 328f969..54e7538 100644 --- a/internal/clients/live_query.go +++ b/internal/clients/live_query.go @@ -127,7 +127,7 @@ func (c *liveQueryCache) trackObjectList(ctx context.Context, list client.Object var r toolscache.ResourceEventHandlerRegistration r, err = i.AddEventHandler(toolscache.FilteringResourceEventHandler{ FilterFunc: func(obj interface{}) bool { - if live_query.IsLive(ctx) { + if !live_query.IsLive(ctx) { _ = i.RemoveEventHandler(r) return false } diff --git a/internal/graph/extensions/live_query/json_patch.go b/internal/graph/extensions/live_query/json_patch.go index 1fedd4b..a2adc5e 100644 --- a/internal/graph/extensions/live_query/json_patch.go +++ b/internal/graph/extensions/live_query/json_patch.go @@ -16,12 +16,8 @@ package live_query import ( "encoding/json" - "fmt" - "reflect" - "strconv" - "strings" - "github.com/google/go-cmp/cmp" + jd "github.com/josephburnett/jd/lib" ) // Op is a JSON Patch operation. @@ -40,146 +36,46 @@ const ( // Operation is a single JSON Patch operation. type Operation struct { - Op Op `json:"op"` - Path string `json:"path"` - From string `json:"from,omitempty"` - Value interface{} `json:"value,omitempty"` + Op Op `json:"op"` + Path string `json:"path"` + From string `json:"from,omitempty"` + Value any `json:"value,omitempty"` } -// JSONPatchReporter is a simple custom reporter that records the operations needed to -// transform one value into another. -type JSONPatchReporter struct { - path cmp.Path - patch []Operation -} - -// PushStep implements cmp.Reporter. -func (r *JSONPatchReporter) PushStep(ps cmp.PathStep) { - r.path = append(r.path, ps) -} - -// Report implements cmp.Reporter. -// we add any unequal operations to the list of operations. -func (r *JSONPatchReporter) Report(rs cmp.Result) { - if rs.Equal() { - return +// CreateJSONPatch creates a JSON patch between two json values. +// TODO(avalanche123): add tests for json patch generation. +func CreateJSONPatch(x, y string) ([]Operation, error) { + xn, err := jd.ReadJsonString(x) + if err != nil { + return nil, err } - assert(len(r.path) > 0) - ps := r.path.Last() - vx, vy := ps.Values() - var op Operation - switch s := ps.(type) { - case cmp.MapIndex, cmp.TypeAssertion: - switch { - // value did not exist. - case !vx.IsValid(): - op = Operation{ - Op: Add, - Path: pathAsJSONPointer(r.path), - Value: vy.Interface(), - } - // value does not exist. - case !vy.IsValid(): - op = Operation{ - Op: Remove, - Path: pathAsJSONPointer(r.path), - } - // value replaced. - default: - op = Operation{ - Op: Replace, - Path: pathAsJSONPointer(r.path), - Value: vy.Interface(), - } + yn, err := jd.ReadJsonString(y) + if err != nil { + return nil, err + } + raw, err := xn.Diff(yn).RenderPatch() + if err != nil { + return nil, err + } + var patch []Operation + if err := json.Unmarshal([]byte(raw), &patch); err != nil { + return nil, err + } + for i := 1; i < len(patch); i++ { + // previous operation and operation + rm, ad := patch[i-1], patch[i] + if rm.Path != ad.Path { + continue } - case cmp.SliceIndex: - kx, ky := s.SplitKeys() - switch { - // value was updated. - case kx == ky: - op = Operation{ + // coalesce remove and add into a replace + if rm.Op == Remove && ad.Op == Add { + patch[i] = Operation{ + Path: ad.Path, Op: Replace, - Path: pathAsJSONPointer(r.path), - Value: vy.Interface(), - } - // value did not exist before. - case kx == -1: - op = Operation{ - Op: Add, - Path: pathAsJSONPointer(r.path[:len(r.path)-1]) + "/" + strconv.Itoa(ky), - Value: vy.Interface(), - } - // value was removed. - case ky == -1: - op = Operation{ - Op: Remove, - Path: pathAsJSONPointer(r.path[:len(r.path)-1]) + "/" + strconv.Itoa(kx), - } - // value was moved. - default: - op = Operation{ - Op: Move, - Path: pathAsJSONPointer(r.path[:len(r.path)-1]) + "/" + strconv.Itoa(ky), - From: pathAsJSONPointer(r.path[:len(r.path)-1]) + "/" + strconv.Itoa(kx), + Value: ad.Value, } + patch = append(patch[:i-1], patch[i:]...) } - default: - panic(fmt.Sprintf("unknown path step type %T", s)) - } - r.patch = append(r.patch, op) -} - -// PopStep implements cmp.Reporter. -func (r *JSONPatchReporter) PopStep() { - r.path = r.path[:len(r.path)-1] -} - -func (r *JSONPatchReporter) GetPatch() []Operation { - return r.patch -} - -var jsonPointerEscaper = strings.NewReplacer("~", "~0", "/", "~1") - -func pathAsJSONPointer(path cmp.Path) string { - var sb strings.Builder - for _, ps := range path { - switch s := ps.(type) { - case cmp.MapIndex: - // json must have string keys. - assert(s.Key().Kind() == reflect.String) - sb.WriteString("/") - sb.WriteString(jsonPointerEscaper.Replace(s.Key().String())) - case cmp.SliceIndex: - // split keys for slices must be handled at a higher level. - assert(s.Key() >= 0) - sb.WriteString("/") - sb.WriteString(strconv.Itoa(s.Key())) - } - } - return sb.String() -} - -func assert(ok bool) { - if !ok { - panic("assertion failure") - } -} - -func parseJSON(in []byte) (out any) { - if err := json.Unmarshal(in, &out); err != nil { - panic(err) // should never occur given previous filter to ensure valid JSON - } - return out -} - -// CreateJSONPatch creates a JSON patch between two json values. -func CreateJSONPatch(x, y []byte) ([]Operation, error) { - if !json.Valid(x) || !json.Valid(y) { - return nil, fmt.Errorf("invalid JSON") - } - r := &JSONPatchReporter{} - if cmp.Equal(parseJSON(x), parseJSON(y), cmp.Reporter(r)) { - return nil, nil } - return r.GetPatch(), nil + return patch, nil } diff --git a/internal/graph/extensions/live_query/live_query.go b/internal/graph/extensions/live_query/live_query.go index cddc97c..2a74434 100644 --- a/internal/graph/extensions/live_query/live_query.go +++ b/internal/graph/extensions/live_query/live_query.go @@ -15,10 +15,10 @@ package live_query import ( - "bytes" "context" _ "embed" "fmt" + "strings" "github.com/99designs/gqlgen/codegen" "github.com/99designs/gqlgen/codegen/config" @@ -193,7 +193,7 @@ func (l LiveQuery) InterceptOperation(ctx context.Context, next graphql.Operatio ctx, cancel := context.WithCancel(ctx) handler := next(ctx) var ( - prevData bytes.Buffer + prevData strings.Builder revision int ) return func(ctx context.Context) *graphql.Response { @@ -206,7 +206,7 @@ func (l LiveQuery) InterceptOperation(ctx context.Context, next graphql.Operatio data := resp.Data // Compare new data with previous response. if prevData.Len() > 0 { - diff, err := CreateJSONPatch(prevData.Bytes(), resp.Data) + diff, err := CreateJSONPatch(prevData.String(), string(data)) if err != nil { cancel() panic(err) diff --git a/internal/graph/extensions/live_query/runtime.go b/internal/graph/extensions/live_query/runtime.go index 68f2797..ab532dc 100644 --- a/internal/graph/extensions/live_query/runtime.go +++ b/internal/graph/extensions/live_query/runtime.go @@ -16,6 +16,7 @@ package live_query import ( "context" + "sync" "sync/atomic" ) @@ -26,6 +27,9 @@ import ( type liveQuery struct { doneCh <-chan struct{} hasChanges uint32 + + mu sync.Mutex + cond *sync.Cond } // HasChangesFn is a func that can be used to check if live query needs to be @@ -41,12 +45,22 @@ var liveQueryCtxKey = liveQueryKey{} // live query resolver to set up periodic live query refresh if changes occurred. func WithLiveQuery(ctx context.Context) (context.Context, HasChangesFn) { lq := &liveQuery{doneCh: ctx.Done()} + lq.cond = sync.NewCond(&lq.mu) return context.WithValue(ctx, liveQueryCtxKey, lq), func() bool { - return atomic.CompareAndSwapUint32(&lq.hasChanges, 1, 0) + if atomic.CompareAndSwapUint32(&lq.hasChanges, 1, 0) { + return true + } + lq.mu.Lock() + defer lq.mu.Unlock() + for !atomic.CompareAndSwapUint32(&lq.hasChanges, 1, 0) { + lq.cond.Wait() + } + return true } } // IsLive returns true if this is a live query context and query is active. +// TODO(avalanche123): add tests. func IsLive(ctx context.Context) bool { if lq, ok := ctx.Value(liveQueryCtxKey).(*liveQuery); ok { select { @@ -60,8 +74,10 @@ func IsLive(ctx context.Context) bool { } // NotifyChanged notifies live query of a change. +// TODO(avalanche123): add tests. func NotifyChanged(ctx context.Context) { if lq, ok := ctx.Value(liveQueryCtxKey).(*liveQuery); ok { atomic.StoreUint32(&lq.hasChanges, 1) + lq.cond.Broadcast() } }