Skip to content

Commit

Permalink
Replace pickle in state persistence in provision cert with json (#412)
Browse files Browse the repository at this point in the history
  • Loading branch information
IsaacYangSLA committed Apr 19, 2022
1 parent f0a0059 commit fd018ee
Showing 1 changed file with 17 additions and 13 deletions.
30 changes: 17 additions & 13 deletions nvflare/lighter/impl/cert.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@
# limitations under the License.

import datetime
import json
import os
import pickle

from cryptography import x509
from cryptography.hazmat.backends import default_backend
Expand Down Expand Up @@ -50,13 +50,13 @@ def __init__(self):

def initialize(self, ctx):
state_dir = self.get_state_dir(ctx)
cert_file = os.path.join(state_dir, "cert.pkl")
cert_file = os.path.join(state_dir, "cert.json")
if os.path.exists(cert_file):
self.persistent_state = pickle.load(open(cert_file, "rb"))
self.serialized_cert = self.persistent_state["root_cert"]
self.persistent_state = json.load(open(cert_file, "rt"))
self.serialized_cert = self.persistent_state["root_cert"].encode("ascii")
self.root_cert = x509.load_pem_x509_certificate(self.serialized_cert, default_backend())
self.pri_key = serialization.load_pem_private_key(
self.persistent_state["root_pri_key"], password=None, backend=default_backend()
self.persistent_state["root_pri_key"].encode("ascii"), password=None, backend=default_backend()
)
self.pub_key = self.pri_key.public_key()
self.subject = self.root_cert.subject
Expand All @@ -69,26 +69,30 @@ def _build_root(self, subject):
self.pri_key = pri_key
self.pub_key = pub_key
self.serialized_cert = serialize_cert(self.root_cert)
self.persistent_state["root_cert"] = self.serialized_cert
self.persistent_state["root_pri_key"] = serialize_pri_key(self.pri_key)
self.persistent_state["root_cert"] = self.serialized_cert.decode("ascii")
self.persistent_state["root_pri_key"] = serialize_pri_key(self.pri_key).decode("ascii")

def _build_write_cert_pair(self, participant, base_name, ctx):
subject = participant.subject
if self.persistent_state and subject in self.persistent_state:
cert = x509.load_pem_x509_certificate(self.persistent_state[subject]["cert"], default_backend())
cert = x509.load_pem_x509_certificate(
self.persistent_state[subject]["cert"].encode("ascii"), default_backend()
)
pri_key = serialization.load_pem_private_key(
self.persistent_state[subject]["pri_key"], password=None, backend=default_backend()
self.persistent_state[subject]["pri_key"].encode("ascii"), password=None, backend=default_backend()
)
else:
pri_key, cert = self.get_pri_key_cert(participant)
self.persistent_state[subject] = dict(cert=serialize_cert(cert), pri_key=serialize_pri_key(pri_key))
self.persistent_state[subject] = dict(
cert=serialize_cert(cert).decode("ascii"), pri_key=serialize_pri_key(pri_key).decode("ascii")
)
dest_dir = self.get_kit_dir(participant, ctx)
with open(os.path.join(dest_dir, f"{base_name}.crt"), "wb") as f:
f.write(serialize_cert(cert))
with open(os.path.join(dest_dir, f"{base_name}.key"), "wb") as f:
f.write(serialize_pri_key(pri_key))
pkcs12 = serialization.pkcs12.serialize_key_and_certificates(
subject.encode("utf-8"), pri_key, cert, None, serialization.BestAvailableEncryption(subject.encode("utf-8"))
subject.encode("ascii"), pri_key, cert, None, serialization.BestAvailableEncryption(subject.encode("ascii"))
)
with open(os.path.join(dest_dir, f"{base_name}.pfx"), "wb") as f:
f.write(pkcs12)
Expand Down Expand Up @@ -163,5 +167,5 @@ def _x509_name(self, cn_name, org_name=None):

def finalize(self, ctx):
state_dir = self.get_state_dir(ctx)
cert_file = os.path.join(state_dir, "cert.pkl")
pickle.dump(self.persistent_state, open(cert_file, "wb"))
cert_file = os.path.join(state_dir, "cert.json")
json.dump(self.persistent_state, open(cert_file, "wt"))

0 comments on commit fd018ee

Please sign in to comment.