Skip to content

Commit

Permalink
Merge pull request #358 from bavix/strict-method-mode
Browse files Browse the repository at this point in the history
[3.x] lower method name
  • Loading branch information
rez1dent3 committed Jul 20, 2024
2 parents d9ecb1c + 8833fc0 commit 83d4717
Show file tree
Hide file tree
Showing 13 changed files with 684 additions and 32 deletions.
4 changes: 4 additions & 0 deletions .github/workflows/integration-test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -56,3 +56,7 @@ jobs:
HTTP_PORT: 6000
with:
entrypoint: example/ms/entrypoint.sh
- name: Run strict mode example
uses: ./
with:
entrypoint: example/strictmode/entrypoint.sh
11 changes: 0 additions & 11 deletions deployments/docker-compose/docker-compose.infra.yml

This file was deleted.

6 changes: 6 additions & 0 deletions deployments/docker-compose/docker-compose.yml
Original file line number Diff line number Diff line change
Expand Up @@ -40,3 +40,9 @@ services:
entrypoint: example/well_known_types/entrypoint.sh
volumes:
- ./../../protogen/example/well_known_types:/go/src/github.com/bavix/gripmock/protogen/example/well_known_types
strict-mode:
image: bavix/gripmock:latest
entrypoint: example/strictmode/entrypoint.sh
volumes:
- ./../../protogen/example/strictmode:/go/src/github.com/bavix/gripmock/protogen/example/strictmode

52 changes: 52 additions & 0 deletions example/strictmode/client/main.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
package main

import (
"context"
"log"
"os"
"time"

grpcinterceptors "github.com/gripmock/grpc-interceptors"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials/insecure"

strictmode "github.com/bavix/gripmock/protogen/example/strictmode"
)

//nolint:mnd
func main() {
// Set up a connection to the server.
conn, err := grpc.NewClient("localhost:4770",
grpc.WithTransportCredentials(insecure.NewCredentials()),
grpc.WithChainUnaryInterceptor(grpcinterceptors.UnaryTimeoutInterceptor(5*time.Second)),
grpc.WithChainStreamInterceptor(grpcinterceptors.StreamTimeoutInterceptor(5*time.Second)))
if err != nil {
log.Fatalf("did not connect: %v", err)
}
defer conn.Close()

c := strictmode.NewGripMockClient(conn)

// Contact the server and print out its response.
name := "GripMock Request"
if len(os.Args) > 1 {
name = os.Args[1]
}
r1, err := c.SayLowerHello(context.Background(), &strictmode.SayLowerHelloRequest{Name: name}, grpc.WaitForReady(true))
if err != nil {
log.Fatalf("error from grpc: %v", err)
}

if r1.GetMessage() != "ok" {
log.Fatalf("message is not ok: %s", r1.GetMessage())
}

r2, err := c.SayTitleHello(context.Background(), &strictmode.SayTitleHelloRequest{Name: name}, grpc.WaitForReady(true))
if err != nil {
log.Fatalf("error from grpc: %v", err)
}

if r2.GetMessage() != "OK" {
log.Fatalf("message is not ok: %s", r2.GetMessage())
}
}
12 changes: 12 additions & 0 deletions example/strictmode/entrypoint.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
#!/usr/bin/env sh

# this file is used by .github/workflows/integration-test.yml

STRICT_METHOD_TITLE=false gripmock \
--stub=example/strictmode/stub \
example/strictmode/method.proto &

# wait for generated files to be available and gripmock is up
gripmock check --silent --timeout=30s

go run example/strictmode/client/*.go
24 changes: 24 additions & 0 deletions example/strictmode/method.proto
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
syntax = "proto3";

package strictmode;

service GripMock {
rpc SayTitleHello (SayTitleHelloRequest) returns (SayTitleHelloReply);
rpc sayLowerHello (sayLowerHelloRequest) returns (sayLowerHelloReply);
}

message sayLowerHelloRequest {
string name = 1;
}

message sayLowerHelloReply {
string message = 1;
}

message SayTitleHelloRequest {
string name = 1;
}

message SayTitleHelloReply {
string message = 1;
}
9 changes: 9 additions & 0 deletions example/strictmode/stub/lower.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
service: GripMock
method: sayLowerHello
input:
ignoreArrayOrder: true
equals:
name: "GripMock Request"
output:
data:
message: "ok"
9 changes: 9 additions & 0 deletions example/strictmode/stub/title.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
service: GripMock
method: SayTitleHello
input:
ignoreArrayOrder: true
equals:
name: "GripMock Request"
output:
data:
message: "OK"
51 changes: 38 additions & 13 deletions protoc-gen-gripmock/generator.go
Original file line number Diff line number Diff line change
Expand Up @@ -105,19 +105,28 @@ type Service struct {
Methods []methodTemplate `json:"methods"`
}

// methodTemplate represents a method in a gRPC service.
// methodTemplate represents a method in a gRPC service.
type methodTemplate struct {
// SvcPackage is the package name of the service.
// For example, if the service name is "Greeter", SvcPackage would be "github.com/bavix/gripmock/protogen/example/Greeter".
SvcPackage string `json:"svc_package"`
// Name is the name of the method.
Name string `json:"name"`
// RpcName is the name of the RPC method, without the "Greeter/" prefix.
// For example, if the RPC method is "Greeter.SayHello", RpcName would be "SayHello".
RpcName string `json:"rpc_name"`
// TitleName is the name of the method, without the "Say" prefix.
// For example, if the method is "SayHello", TitleName would be "Hello".
TitleName string `json:"title_name"`
// ServiceName is the name of the service.
// For example, if the service name is "Greeter", ServiceName would be "Greeter".
ServiceName string `json:"service_name"`
// MethodType is the type of the method, which can be "standard", "server-stream", "client-stream", or "bidirectional".
MethodType string `json:"method_type"`
// Input is the name of the input message for the method.
// For example, if the input message is "HelloRequest", Input would be "HelloRequest".
Input string `json:"input"`
// Output is the name of the output message for the method.
// For example, if the output message is "HelloResponse", Output would be "HelloResponse".
Output string `json:"output"`
}

Expand Down Expand Up @@ -237,7 +246,7 @@ func resolveDependencies(protos []*descriptorpb.FileDescriptorProto) map[string]
var (
// aliases is a map that keeps track of package aliases. The key is the alias and the value
// is a boolean indicating whether the alias is used or not.
aliases = map[string]bool{}
aliases = make(map[string]struct{}, 128)

// aliasNum is an integer that keeps track of the number of used aliases. It is used to
// generate new unique aliases.
Expand All @@ -246,7 +255,7 @@ var (
// packages is a map that stores the package names as keys and their corresponding aliases
// as values. The package names are the full go package names and the aliases are the
// generated or specified aliases for the packages.
packages = map[string]string{}
packages = make(map[string]string, 32)
)

// getGoPackage returns the go package alias and the go package name
Expand All @@ -262,6 +271,7 @@ var (
//
// The go_package option format is: package_name;alias.
func getGoPackage(proto *descriptorpb.FileDescriptorProto) (alias string, goPackage string) {
// Get the go_package option from the file descriptor.
goPackage = proto.GetOptions().GetGoPackage()
if goPackage == "" {
return
Expand Down Expand Up @@ -292,13 +302,16 @@ func getGoPackage(proto *descriptorpb.FileDescriptorProto) (alias string, goPack
}

// If the alias already exists, append a number to it.
if ok := aliases[alias]; ok {
if _, ok := aliases[alias]; ok {
alias = fmt.Sprintf("%s%d", alias, aliasNum)
aliasNum++
}

// Add the alias to the aliases map.
aliases[alias] = struct{}{}

// Add the package to the packages map with its alias.
packages[goPackage] = alias
aliases[alias] = true

return
}
Expand All @@ -307,10 +320,11 @@ func getGoPackage(proto *descriptorpb.FileDescriptorProto) (alias string, goPack
// a slice of Service structs, each representing a gRPC service.
//
// The function iterates over each file descriptor and extracts the services
// defined in each file. It then populates the Services struct with relevant
// information like the service name, package name, and methods. The methods
// include information such as the method name, input and output types, and the
// type of method (standard, server-stream, client-stream, or bidirectional).
// defined in each file. It populates the Service struct with relevant
// information like the service name, package name, and methods. Each method
// represents a gRPC method and includes information such as the method name,
// input and output types, and the type of method (standard, server-stream,
// client-stream, or bidirectional).
//
// Parameters:
// - protos: A slice of FileDescriptorProto structs representing the file
Expand All @@ -319,6 +333,7 @@ func getGoPackage(proto *descriptorpb.FileDescriptorProto) (alias string, goPack
// Returns:
// - svcTmp: A slice of Service structs representing the extracted services.
func extractServices(protos []*descriptorpb.FileDescriptorProto) []Service {
// svcTmp will hold the extracted services
var svcTmp []Service
title := cases.Title(language.English, cases.NoLower)

Expand Down Expand Up @@ -350,7 +365,8 @@ func extractServices(protos []*descriptorpb.FileDescriptorProto) []Service {

// Populate the methodTemplate struct
methods[j] = methodTemplate{
Name: title.String(*method.Name),
RpcName: method.GetName(),
TitleName: title.String(method.GetName()),
SvcPackage: s.Package,
ServiceName: svc.GetName(),
Input: getMessageType(protos, method.GetInputType()),
Expand All @@ -367,7 +383,16 @@ func extractServices(protos []*descriptorpb.FileDescriptorProto) []Service {
return svcTmp
}

// getMessageType takes a slice of file descriptors and a type string,
// and returns a fully qualified message type.
//
// The message type is split into package and type parts, and the
// function iterates over the protos to find the target message.
// If the target message is found, the function returns the fully
// qualified message type, otherwise it returns the target type.
func getMessageType(protos []*descriptorpb.FileDescriptorProto, tipe string) string {
title := cases.Title(language.English, cases.NoLower)

// Split the message type into package and type parts
split := strings.Split(tipe, ".")[1:]
targetPackage := strings.Join(split[:len(split)-1], ".")
Expand All @@ -391,13 +416,13 @@ func getMessageType(protos []*descriptorpb.FileDescriptorProto, tipe string) str
}

// Return the fully qualified message type
return fmt.Sprintf("%s%s", alias, msg.GetName())
return fmt.Sprintf("%s%s", alias, title.String(msg.GetName()))
}
}
}

// Return the target type if no match was found
return targetType
return targetPackage + "." + title.String(targetType)
}

// keywords is a map that contains all the reserved keywords in Go.
Expand Down
16 changes: 8 additions & 8 deletions protoc-gen-gripmock/server.tmpl
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@ type {{.Name}} struct{
{{end}}

{{ define "standard_method" }}
func (s *{{.ServiceName}}) {{.Name}}(ctx context.Context, in *{{.Input}}) (*{{.Output}},error){
func (s *{{.ServiceName}}) {{.TitleName}}(ctx context.Context, in *{{.Input}}) (*{{.Output}},error){
out := &{{.Output}}{}
// Retrieve metadata from the incoming context.
// The metadata is used to find the stub for the method being called.
Expand All @@ -159,15 +159,15 @@ func (s *{{.ServiceName}}) {{.Name}}(ctx context.Context, in *{{.Input}}) (*{{.O
// The stub defines the input and output messages for the method.
// If the stub is found, its output message is returned.
// If the stub is not found, an error is returned.
err := findStub(ctx, s.__builder__.Config().HTTPAddr, "{{.ServiceName}}", "{{.Name}}", md, in, out)
err := findStub(ctx, s.__builder__.Config().HTTPAddr, "{{.ServiceName}}", "{{.RpcName}}", md, in, out)

// Return the output message and any error encountered while finding the stub.
return out, err
}
{{ end }}

{{ define "server_stream_method" }}
func (s *{{.ServiceName}}) {{.Name}}(in *{{.Input}},srv {{.SvcPackage}}{{.ServiceName}}_{{.Name}}Server) error {
func (s *{{.ServiceName}}) {{.TitleName}}(in *{{.Input}},srv {{.SvcPackage}}{{.ServiceName}}_{{.TitleName}}Server) error {
out := &{{.Output}}{}
// Retrieve metadata from the incoming context.
// The metadata is used to find the stub for the method being called.
Expand All @@ -178,7 +178,7 @@ func (s *{{.ServiceName}}) {{.Name}}(in *{{.Input}},srv {{.SvcPackage}}{{.Servic
// The stub defines the input and output messages for the method.
// If the stub is found, its output message is returned.
// If the stub is not found, an error is returned.
err := findStub(ctx, s.__builder__.Config().HTTPAddr, "{{.ServiceName}}", "{{.Name}}", md, in, out)
err := findStub(ctx, s.__builder__.Config().HTTPAddr, "{{.ServiceName}}", "{{.RpcName}}", md, in, out)
if err != nil {
// Return the error encountered while finding the stub.
return err
Expand All @@ -192,7 +192,7 @@ func (s *{{.ServiceName}}) {{.Name}}(in *{{.Input}},srv {{.SvcPackage}}{{.Servic
{{ end }}

{{ define "client_stream_method"}}
func (s *{{.ServiceName}}) {{.Name}}(srv {{.SvcPackage}}{{.ServiceName}}_{{.Name}}Server) error {
func (s *{{.ServiceName}}) {{.TitleName}}(srv {{.SvcPackage}}{{.ServiceName}}_{{.TitleName}}Server) error {
out := &{{.Output}}{}
// Handle the client-streaming RPC.
// This loop will continue until the client closes the RPC.
Expand All @@ -214,7 +214,7 @@ func (s *{{.ServiceName}}) {{.Name}}(srv {{.SvcPackage}}{{.ServiceName}}_{{.Name
// The stub defines the input and output messages for the method.
// If the stub is found, its output message is returned.
// If the stub is not found, an error is returned.
err = findStub(ctx, s.__builder__.Config().HTTPAddr, "{{.ServiceName}}","{{.Name}}", md, input, out)
err = findStub(ctx, s.__builder__.Config().HTTPAddr, "{{.ServiceName}}","{{.RpcName}}", md, input, out)
if err != nil {
// If there is an error finding the stub, return the error.
return err
Expand All @@ -224,7 +224,7 @@ func (s *{{.ServiceName}}) {{.Name}}(srv {{.SvcPackage}}{{.ServiceName}}_{{.Name
{{ end }}

{{ define "bidirectional_method"}}
func (s *{{.ServiceName}}) {{.Name}}(srv {{.SvcPackage}}{{.ServiceName}}_{{.Name}}Server) error {
func (s *{{.ServiceName}}) {{.TitleName}}(srv {{.SvcPackage}}{{.ServiceName}}_{{.TitleName}}Server) error {
// Handle the bidirectional RPC.
// This loop will continue until the client closes the RPC.
// For each input message received from the client, it will find the stub
Expand Down Expand Up @@ -252,7 +252,7 @@ func (s *{{.ServiceName}}) {{.Name}}(srv {{.SvcPackage}}{{.ServiceName}}_{{.Name
// The stub defines the input and output messages for the method.
// If the stub is found, its output message is returned.
// If the stub is not found, an error is returned.
err = findStub(ctx, s.__builder__.Config().HTTPAddr, "{{.ServiceName}}","{{.Name}}", md, input, out)
err = findStub(ctx, s.__builder__.Config().HTTPAddr, "{{.ServiceName}}","{{.RpcName}}", md, input, out)
if err != nil {
return err
}
Expand Down
Loading

0 comments on commit 83d4717

Please sign in to comment.