Skip to content

Commit

Permalink
Implement predefined field constraints (#61)
Browse files Browse the repository at this point in the history
Like protovalidate-go and protovalidate-java, we need to adjust the code
to handle dynamic descriptor sets more robustly, since we need to jump
between resolving the protovalidate standard rules and the predefined
rule extensions. This necessitates adding a couple of additions to the
API surface, namely `ValidatorFactory::SetMessageFactory` and
`ValidatorFactory::SetAllowUnknownFields`, which controls instantiation
of unknown dynamic types and whether or not to ignore unresolved rules,
respectively. Like other protovalidate runtimes, we will default to
failing compilation when unknown predefined rules are encountered. This
should not break existing users but will prevent silent incorrect
behavior.

TODO:
- [x] Skip reparse when there are no empty fields—this way we can
avoid pessimizing the common case
- [x] Add an option to fail when unknown rule fields are unable to be
resolved.
- [x] Update for protobuf changes in
bufbuild/protovalidate#246.

This will depend on bufbuild/protovalidate#246.
  • Loading branch information
jchadwick-buf committed Sep 23, 2024
1 parent a01c79d commit b03f050
Show file tree
Hide file tree
Showing 20 changed files with 306 additions and 57 deletions.
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ COPYRIGHT_YEARS := 2023
LICENSE_IGNORE := -e internal/testdata/
LICENSE_HEADER_VERSION := 0294fdbe1ce8649ebaf5e87e8cdd588e33730bbb
# NOTE: Keep this version in sync with the version in `/bazel/deps.bzl`.
PROTOVALIDATE_VERSION ?= v0.7.1
PROTOVALIDATE_VERSION ?= v0.8.1

# Set to use a different compiler. For example, `GO=go1.18rc1 make test`.
GO ?= go
Expand Down
6 changes: 3 additions & 3 deletions bazel/deps.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -65,10 +65,10 @@ _dependencies = {
},
# NOTE: Keep Version in sync with `/Makefile`.
"com_github_bufbuild_protovalidate": {
"sha256": "ccb3952c38397d2cb53fe841af66b05fc012dd17fa754cbe35d9abb547cdf92d",
"strip_prefix": "protovalidate-0.7.1",
"sha256": "c637c8cbaf71b6dc38171e47c2c736581b4cfef385984083561480367659d14f",
"strip_prefix": "protovalidate-0.8.1",
"urls": [
"https://github.com/bufbuild/protovalidate/archive/v0.7.1.tar.gz",
"https://github.com/bufbuild/protovalidate/archive/v0.8.1.tar.gz",
],
},
}
Expand Down
5 changes: 2 additions & 3 deletions buf/validate/conformance/runner.cc
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,7 @@
namespace buf::validate::conformance {

harness::TestConformanceResponse TestRunner::runTest(
const harness::TestConformanceRequest& request,
const google::protobuf::DescriptorPool* descriptorPool) {
const harness::TestConformanceRequest& request) {
harness::TestConformanceResponse response;
for (const auto& tc : request.cases()) {
auto& result = response.mutable_results()->operator[](tc.first);
Expand All @@ -32,7 +31,7 @@ harness::TestConformanceResponse TestRunner::runTest(
*result.mutable_unexpected_error() = "could not parse type url " + dyn.type_url();
continue;
}
const auto* desc = descriptorPool->FindMessageTypeByName(dyn.type_url().substr(pos + 1));
const auto* desc = descriptorPool_->FindMessageTypeByName(dyn.type_url().substr(pos + 1));
if (desc == nullptr) {
*result.mutable_unexpected_error() = "could not find descriptor for type " + dyn.type_url();
} else {
Expand Down
13 changes: 9 additions & 4 deletions buf/validate/conformance/runner.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,17 +22,22 @@ namespace buf::validate::conformance {

class TestRunner {
public:
explicit TestRunner() : validatorFactory_(ValidatorFactory::New().value()) {}
explicit TestRunner(
const google::protobuf::DescriptorPool* descriptorPool =
google::protobuf::DescriptorPool::generated_pool())
: descriptorPool_(descriptorPool), validatorFactory_(ValidatorFactory::New().value()) {
validatorFactory_->SetMessageFactory(&messageFactory_, descriptorPool_);
validatorFactory_->SetAllowUnknownFields(false);
}

harness::TestConformanceResponse runTest(
const harness::TestConformanceRequest& request,
const google::protobuf::DescriptorPool* descriptorPool);
harness::TestConformanceResponse runTest(const harness::TestConformanceRequest& request);
harness::TestResult runTestCase(
const google::protobuf::Descriptor* desc, const google::protobuf::Any& dyn);
harness::TestResult runTestCase(const google::protobuf::Message& message);

private:
google::protobuf::DynamicMessageFactory messageFactory_;
const google::protobuf::DescriptorPool* descriptorPool_;
std::unique_ptr<ValidatorFactory> validatorFactory_;
google::protobuf::Arena arena_;
};
Expand Down
7 changes: 4 additions & 3 deletions buf/validate/conformance/runner_main.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,15 @@
#include "buf/validate/conformance/runner.h"

int main(int argc, char** argv) {
google::protobuf::DescriptorPool descriptorPool;
buf::validate::conformance::TestRunner runner;
google::protobuf::DescriptorPool descriptorPool{
google::protobuf::DescriptorPool::generated_pool()};
buf::validate::conformance::harness::TestConformanceRequest request;
request.ParseFromIstream(&std::cin);
for (const auto& file : request.fdset().file()) {
descriptorPool.BuildFile(file);
}
auto response = runner.runTest(request, &descriptorPool);
buf::validate::conformance::TestRunner runner{&descriptorPool};
auto response = runner.runTest(request);
response.SerializeToOstream(&std::cout);
return 0;
}
15 changes: 15 additions & 0 deletions buf/validate/internal/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,11 @@ cc_library(
"@com_google_cel_cpp//eval/public:activation",
"@com_google_cel_cpp//eval/public:cel_expression",
"@com_google_cel_cpp//eval/public/structs:cel_proto_wrapper",
"@com_google_cel_cpp//eval/public/containers:field_access",
"@com_google_cel_cpp//eval/public/containers:field_backed_list_impl",
"@com_google_cel_cpp//eval/public/containers:field_backed_map_impl",
"@com_google_cel_cpp//parser",
"@com_google_cel_cpp//base:value"
],
)

Expand All @@ -44,15 +48,26 @@ cc_library(
deps = [
"@com_google_absl//absl/status",
"@com_google_protobuf//:protobuf",
":message_factory",
],
)

cc_library(
name = "message_factory",
srcs = ["message_factory.cc"],
hdrs = ["message_factory.h"],
deps = [
"@com_google_protobuf//:protobuf",
]
)

cc_library(
name = "message_rules",
srcs = ["message_rules.cc"],
hdrs = ["message_rules.h"],
deps = [
":field_rules",
":message_factory",
"@com_github_bufbuild_protovalidate//proto/protovalidate/buf/validate:validate_proto_cc",
"@com_google_absl//absl/status",
"@com_google_cel_cpp//eval/public:cel_expression",
Expand Down
40 changes: 36 additions & 4 deletions buf/validate/internal/cel_constraint_rules.cc
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,10 @@

#include "buf/validate/internal/cel_constraint_rules.h"

#include "base/values/struct_value.h"
#include "eval/public/containers/field_access.h"
#include "eval/public/containers/field_backed_list_impl.h"
#include "eval/public/containers/field_backed_map_impl.h"
#include "eval/public/structs/cel_proto_wrapper.h"
#include "parser/parser.h"

Expand Down Expand Up @@ -57,10 +61,31 @@ absl::Status ProcessConstraint(
return absl::OkStatus();
}

cel::runtime::CelValue ProtoFieldToCelValue(
const google::protobuf::Message* message,
const google::protobuf::FieldDescriptor* field,
google::protobuf::Arena* arena) {
if (field->is_map()) {
return cel::runtime::CelValue::CreateMap(
google::protobuf::Arena::Create<cel::runtime::FieldBackedMapImpl>(
arena, message, field, arena));
} else if (field->is_repeated()) {
return cel::runtime::CelValue::CreateList(
google::protobuf::Arena::Create<cel::runtime::FieldBackedListImpl>(
arena, message, field, arena));
} else if (cel::runtime::CelValue result;
cel::runtime::CreateValueFromSingleField(message, field, arena, &result).ok()) {
return result;
}
return cel::runtime::CelValue::CreateNull();
}

} // namespace

absl::Status CelConstraintRules::Add(
google::api::expr::runtime::CelExpressionBuilder& builder, Constraint constraint) {
google::api::expr::runtime::CelExpressionBuilder& builder,
Constraint constraint,
const google::protobuf::FieldDescriptor* rule) {
auto pexpr_or = cel::parser::Parse(constraint.expression());
if (!pexpr_or.ok()) {
return pexpr_or.status();
Expand All @@ -71,20 +96,21 @@ absl::Status CelConstraintRules::Add(
return expr_or.status();
}
std::unique_ptr<cel::runtime::CelExpression> expr = std::move(expr_or).value();
exprs_.emplace_back(CompiledConstraint{std::move(constraint), std::move(expr)});
exprs_.emplace_back(CompiledConstraint{std::move(constraint), std::move(expr), rule});
return absl::OkStatus();
}

absl::Status CelConstraintRules::Add(
google::api::expr::runtime::CelExpressionBuilder& builder,
std::string_view id,
std::string_view message,
std::string_view expression) {
std::string_view expression,
const google::protobuf::FieldDescriptor* rule) {
Constraint constraint;
*constraint.mutable_id() = id;
*constraint.mutable_message() = message;
*constraint.mutable_expression() = expression;
return Add(builder, constraint);
return Add(builder, constraint, rule);
}

absl::Status CelConstraintRules::ValidateCel(
Expand All @@ -94,11 +120,17 @@ absl::Status CelConstraintRules::ValidateCel(
activation.InsertValue("rules", rules_);
activation.InsertValue("now", cel::runtime::CelValue::CreateTimestamp(absl::Now()));
absl::Status status = absl::OkStatus();

for (const auto& expr : exprs_) {
if (rules_.IsMessage() && expr.rule) {
activation.InsertValue(
"rule", ProtoFieldToCelValue(rules_.MessageOrDie(), expr.rule, ctx.arena));
}
status = ProcessConstraint(ctx, fieldName, activation, expr);
if (ctx.shouldReturn(status)) {
break;
}
activation.RemoveValueEntry("rule");
}
activation.RemoveValueEntry("rules");
return status;
Expand Down
10 changes: 7 additions & 3 deletions buf/validate/internal/cel_constraint_rules.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

#include <string_view>

#include "buf/validate/expression.pb.h"
#include "buf/validate/validate.pb.h"
#include "buf/validate/internal/constraint_rules.h"
#include "eval/public/activation.h"
#include "eval/public/cel_expression.h"
Expand All @@ -28,6 +28,7 @@ namespace buf::validate::internal {
struct CompiledConstraint {
buf::validate::Constraint constraint;
std::unique_ptr<google::api::expr::runtime::CelExpression> expr;
const google::protobuf::FieldDescriptor* rule;
};

// An abstract base class for constraint with rules that are compiled into CEL expressions.
Expand All @@ -38,12 +39,15 @@ class CelConstraintRules : public ConstraintRules {
using Base::Base;

absl::Status Add(
google::api::expr::runtime::CelExpressionBuilder& builder, Constraint constraint);
google::api::expr::runtime::CelExpressionBuilder& builder,
Constraint constraint,
const google::protobuf::FieldDescriptor* rule);
absl::Status Add(
google::api::expr::runtime::CelExpressionBuilder& builder,
std::string_view id,
std::string_view message,
std::string_view expression);
std::string_view expression,
const google::protobuf::FieldDescriptor* rule);
[[nodiscard]] const std::vector<CompiledConstraint>& getExprs() const { return exprs_; }

// Validate all the cel rules given the activation that already has 'this' bound.
Expand Down
39 changes: 33 additions & 6 deletions buf/validate/internal/cel_rules.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

#include "absl/status/status.h"
#include "buf/validate/internal/cel_constraint_rules.h"
#include "buf/validate/internal/message_factory.h"
#include "buf/validate/validate.pb.h"
#include "google/protobuf/arena.h"
#include "google/protobuf/descriptor.h"
Expand All @@ -24,22 +25,48 @@ namespace buf::validate::internal {

template <typename R>
absl::Status BuildCelRules(
std::unique_ptr<MessageFactory>& messageFactory,
bool allowUnknownFields,
google::protobuf::Arena* arena,
google::api::expr::runtime::CelExpressionBuilder& builder,
const R& rules,
CelConstraintRules& result) {
result.setRules(&rules, arena);
// Look for constraints on the set fields.
std::vector<const google::protobuf::FieldDescriptor*> fields;
R::GetReflection()->ListFields(rules, &fields);
google::protobuf::Message* reparsedRules{};
if (messageFactory && rules.unknown_fields().field_count() > 0) {
reparsedRules = messageFactory->messageFactory()
->GetPrototype(messageFactory->descriptorPool()->FindMessageTypeByName(
rules.GetTypeName()))
->New(arena);
if (!Reparse(*messageFactory, rules, reparsedRules)) {
reparsedRules = nullptr;
}
}
if (reparsedRules) {
if (!allowUnknownFields &&
!reparsedRules->GetReflection()->GetUnknownFields(*reparsedRules).empty()) {
return absl::FailedPreconditionError(
absl::StrCat("unknown constraints in ", reparsedRules->GetTypeName()));
}
result.setRules(reparsedRules, arena);
reparsedRules->GetReflection()->ListFields(*reparsedRules, &fields);
} else {
if (!allowUnknownFields && !R::GetReflection()->GetUnknownFields(rules).empty()) {
return absl::FailedPreconditionError(
absl::StrCat("unknown constraints in ", rules.GetTypeName()));
}
result.setRules(&rules, arena);
R::GetReflection()->ListFields(rules, &fields);
}
for (const auto* field : fields) {
if (!field->options().HasExtension(buf::validate::priv::field)) {
if (!field->options().HasExtension(buf::validate::predefined)) {
continue;
}
const auto& fieldLvl = field->options().GetExtension(buf::validate::priv::field);
const auto& fieldLvl = field->options().GetExtension(buf::validate::predefined);
for (const auto& constraint : fieldLvl.cel()) {
auto status =
result.Add(builder, constraint.id(), constraint.message(), constraint.expression());
auto status = result.Add(
builder, constraint.id(), constraint.message(), constraint.expression(), field);
if (!status.ok()) {
return status;
}
Expand Down
2 changes: 1 addition & 1 deletion buf/validate/internal/constraint_rules.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
#pragma once

#include "absl/status/status.h"
#include "buf/validate/expression.pb.h"
#include "buf/validate/validate.pb.h"
#include "eval/public/cel_value.h"
#include "google/protobuf/arena.h"
#include "google/protobuf/message.h"
Expand Down
4 changes: 0 additions & 4 deletions buf/validate/internal/constraints.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,17 +16,13 @@

#include "absl/status/statusor.h"
#include "buf/validate/internal/extra_func.h"
#include "buf/validate/priv/private.pb.h"
#include "buf/validate/validate.pb.h"
#include "eval/public/builtin_func_registrar.h"
#include "eval/public/cel_expr_builder_factory.h"
#include "eval/public/cel_value.h"
#include "eval/public/containers/field_access.h"
#include "eval/public/containers/field_backed_list_impl.h"
#include "eval/public/containers/field_backed_map_impl.h"
#include "eval/public/structs/cel_proto_wrapper.h"
#include "google/protobuf/any.pb.h"
#include "google/protobuf/descriptor.pb.h"
#include "google/protobuf/dynamic_message.h"
#include "google/protobuf/util/message_differencer.h"

Expand Down
2 changes: 1 addition & 1 deletion buf/validate/internal/constraints_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ class ExpressionTest : public testing::Test {
constraint.set_expression(std::move(expr));
constraint.set_message(std::move(message));
constraint.set_id(std::move(id));
return constraints_->Add(*builder_, constraint);
return constraints_->Add(*builder_, constraint, nullptr);
}

absl::Status Validate(
Expand Down
Loading

0 comments on commit b03f050

Please sign in to comment.