Skip to content

Commit

Permalink
fix: improve key scanning (#6374)
Browse files Browse the repository at this point in the history
Description
---
Improves utxo scanning with various fixes

Motivation and Context
---
Scanning for new outputs on the blockchain can be very slow.

How Has This Been Tested?
---
Manual
  • Loading branch information
SWvheerden committed Jun 20, 2024
1 parent b773173 commit 43b2317
Show file tree
Hide file tree
Showing 7 changed files with 66 additions and 76 deletions.
8 changes: 4 additions & 4 deletions base_layer/common_types/src/wallet_types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -65,14 +65,14 @@ impl Default for WalletType {
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LedgerWallet {
account: u64,
pub pubkey: Option<RistrettoPublicKey>,
pub public_alpha: Option<RistrettoPublicKey>,
pub network: Network,
}

impl Display for LedgerWallet {
fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
write!(f, "account {}", self.account)?;
write!(f, "pubkey {}", self.pubkey.is_some())?;
write!(f, "pubkey {}", self.public_alpha.is_some())?;
Ok(())
}
}
Expand All @@ -81,10 +81,10 @@ impl Display for LedgerWallet {
const WALLET_CLA: u8 = 0x80;

impl LedgerWallet {
pub fn new(account: u64, network: Network, pubkey: Option<RistrettoPublicKey>) -> Self {
pub fn new(account: u64, network: Network, public_alpha: Option<RistrettoPublicKey>) -> Self {
Self {
account,
pubkey,
public_alpha,
network,
}
}
Expand Down
54 changes: 35 additions & 19 deletions base_layer/core/src/transactions/key_manager/inner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ use tari_utilities::{hex::Hex, ByteArray};
use tokio::sync::RwLock;

const LOG_TARGET: &str = "c::bn::key_manager::key_manager_service";
const KEY_MANAGER_MAX_SEARCH_DEPTH: u64 = 1_000_000;
const TRANSACTION_KEY_MANAGER_MAX_SEARCH_DEPTH: u64 = 1_000_000;

use crate::{
common::ConfidentialOutputHasher,
Expand Down Expand Up @@ -242,9 +242,11 @@ where TBackend: KeyManagerBackend<PublicKey> + 'static
KeyId::Derived { branch, label, index } => {
let public_alpha = match &self.wallet_type {
WalletType::Software(_k, pk) => pk,
WalletType::Ledger(ledger) => ledger.pubkey.as_ref().ok_or(KeyManagerServiceError::LedgerError(
"Key manager set to use ledger, ledger alpha public key missing".to_string(),
))?,
WalletType::Ledger(ledger) => {
ledger.public_alpha.as_ref().ok_or(KeyManagerServiceError::LedgerError(
"Key manager set to use ledger, ledger alpha public key missing".to_string(),
))?
},
};
let km = self
.key_managers
Expand Down Expand Up @@ -341,11 +343,20 @@ where TBackend: KeyManagerBackend<PublicKey> + 'static

let current_index = km.key_index();

for i in 0u64..current_index + KEY_MANAGER_MAX_SEARCH_DEPTH {
let public_key = PublicKey::from_secret_key(&km.derive_key(i)?.key);
for i in 0u64..TRANSACTION_KEY_MANAGER_MAX_SEARCH_DEPTH {
let index = current_index + i;
let public_key = PublicKey::from_secret_key(&km.derive_key(index)?.key);
if public_key == *key {
trace!(target: LOG_TARGET, "Key found in {} Key Chain at index {}", branch, i);
return Ok(i);
return Ok(index);
}
if i <= current_index && i != 0u64 {
let index = current_index - i;
let public_key = PublicKey::from_secret_key(&km.derive_key(index)?.key);
if public_key == *key {
trace!(target: LOG_TARGET, "Key found in {} Key Chain at index {}", branch, index);
return Ok(index);
}
}
}

Expand All @@ -363,11 +374,21 @@ where TBackend: KeyManagerBackend<PublicKey> + 'static

let current_index = km.key_index();

for i in 0u64..current_index + KEY_MANAGER_MAX_SEARCH_DEPTH {
let private_key = &km.derive_key(i)?.key;
// its most likely that the key is close to the current index, so we start searching from the current index
for i in 0u64..TRANSACTION_KEY_MANAGER_MAX_SEARCH_DEPTH {
let index = current_index + i;
let private_key = &km.derive_key(index)?.key;
if private_key == key {
trace!(target: LOG_TARGET, "Key found in {} Key Chain at index {}", branch, i);
return Ok(i);
trace!(target: LOG_TARGET, "Key found in {} Key Chain at index {}", branch, index);
return Ok(index);
}
if i <= current_index && i != 0u64 {
let index = current_index - i;
let private_key = &km.derive_key(index)?.key;
if private_key == key {
trace!(target: LOG_TARGET, "Key found in {} Key Chain at index {}", branch, index);
return Ok(index);
}
}
}

Expand Down Expand Up @@ -418,7 +439,7 @@ where TBackend: KeyManagerBackend<PublicKey> + 'static
},
KeyId::Derived { branch, label, index } => match &self.wallet_type {
WalletType::Ledger(_) => Err(KeyManagerServiceError::LedgerPrivateKeyInaccessible),
WalletType::Software(k, _pk) => {
WalletType::Software(private_alpha, _pk) => {
let km = self
.key_managers
.get(branch)
Expand All @@ -431,7 +452,7 @@ where TBackend: KeyManagerBackend<PublicKey> + 'static
let private_key = PrivateKey::from_uniform_bytes(hasher.as_ref()).map_err(|_| {
KeyManagerServiceError::UnknownError(format!("Invalid private key for {}", label))
})?;
let private_key = private_key + k;
let private_key = private_key + private_alpha;
Ok(private_key)
},
},
Expand Down Expand Up @@ -1218,12 +1239,7 @@ where TBackend: KeyManagerBackend<PublicKey> + 'static
self.crypto_factories
.range_proof
.verify_mask(output.commitment(), &private_key, value.into())?;
// Detect the branch we need to scan on for the key.
let branch = if output.is_coinbase() {
TransactionKeyManagerBranch::Coinbase.get_branch_key()
} else {
TransactionKeyManagerBranch::CommitmentMask.get_branch_key()
};
let branch = TransactionKeyManagerBranch::CommitmentMask.get_branch_key();
let key = match self.find_private_key_index(&branch, &private_key).await {
Ok(index) => {
self.update_current_key_index_if_higher(&branch, index).await?;
Expand Down
3 changes: 0 additions & 3 deletions base_layer/core/src/transactions/key_manager/interface.rs
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,6 @@ pub enum TxoStage {
#[derive(Clone, Copy, EnumIter)]
pub enum TransactionKeyManagerBranch {
DataEncryption = 0x00,
Coinbase = 0x01,
MetadataEphemeralNonce = 0x02,
CommitmentMask = 0x03,
Nonce = 0x04,
Expand All @@ -71,7 +70,6 @@ impl TransactionKeyManagerBranch {
pub fn get_branch_key(self) -> String {
match self {
TransactionKeyManagerBranch::DataEncryption => "data encryption".to_string(),
TransactionKeyManagerBranch::Coinbase => "coinbase".to_string(),
TransactionKeyManagerBranch::CommitmentMask => "commitment mask".to_string(),
TransactionKeyManagerBranch::Nonce => "nonce".to_string(),
TransactionKeyManagerBranch::MetadataEphemeralNonce => "metadata ephemeral nonce".to_string(),
Expand All @@ -83,7 +81,6 @@ impl TransactionKeyManagerBranch {
pub fn from_key(key: &str) -> Self {
match key {
"data encryption" => TransactionKeyManagerBranch::DataEncryption,
"coinbase" => TransactionKeyManagerBranch::Coinbase,
"commitment mask" => TransactionKeyManagerBranch::CommitmentMask,
"metadata ephemeral nonce" => TransactionKeyManagerBranch::MetadataEphemeralNonce,
"kernel nonce" => TransactionKeyManagerBranch::KernelNonce,
Expand Down
2 changes: 1 addition & 1 deletion base_layer/core/src/validation/block_body/test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -239,7 +239,7 @@ async fn it_allows_multiple_coinbases() {

let (mut block, coinbase) = blockchain.create_unmined_block(block_spec!("A1", parent: "GB")).await;
let spend_key_id = KeyId::Managed {
branch: TransactionKeyManagerBranch::Coinbase.get_branch_key(),
branch: TransactionKeyManagerBranch::CommitmentMask.get_branch_key(),
index: 42,
};
let wallet_payment_address = TariAddress::default();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -243,7 +243,7 @@ where
if start.elapsed().as_millis() > 0 {
trace!(
target: LOG_TARGET,
"sqlite profile - insert_imported_key: lock {} + db_op {} = {} ms",
"sqlite profile - get_imported_key: lock {} + db_op {} = {} ms",
acquire_lock.as_millis(),
(start.elapsed() - acquire_lock).as_millis(),
start.elapsed().as_millis()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,10 @@
use std::time::Instant;

use log::*;
use tari_common_types::{transaction::TxId, types::FixedHash};
use tari_common_types::{
transaction::TxId,
types::{FixedHash, PrivateKey},
};
use tari_core::transactions::{
key_manager::{TariKeyId, TransactionKeyManagerBranch, TransactionKeyManagerInterface, TransactionKeyManagerLabel},
tari_amount::MicroMinotari,
Expand All @@ -35,6 +38,7 @@ use tari_core::transactions::{
WalletOutput,
},
};
use tari_crypto::keys::SecretKey;
use tari_key_manager::key_manager_service::KeyId;
use tari_script::{inputs, script, ExecutionStack, Opcode, TariScript};
use tari_utilities::hex::Hex;
Expand Down Expand Up @@ -156,8 +160,6 @@ where
tx_id,
hash: *hash,
});
self.update_outputs_script_private_key_and_update_key_manager_index(output)
.await?;
trace!(
target: LOG_TARGET,
"Output {} with value {} with {} recovered",
Expand Down Expand Up @@ -200,11 +202,16 @@ where
known_scripts: &[KnownOneSidedPaymentScript],
) -> Result<Option<(ExecutionStack, TariKeyId)>, OutputManagerError> {
let (input_data, script_key) = if script == &script!(Nop) {
// This is a nop, so we can just create a new key an create the input stack.
let key = KeyId::Derived {
branch: TransactionKeyManagerBranch::CommitmentMask.get_branch_key(),
label: TransactionKeyManagerLabel::ScriptKey.get_branch_key(),
index: spending_key.managed_index().unwrap(),
// This is a nop, so we can just create a new key for the input stack.
let key = if let Some(index) = spending_key.managed_index() {
KeyId::Derived {
branch: TransactionKeyManagerBranch::CommitmentMask.get_branch_key(),
label: TransactionKeyManagerLabel::ScriptKey.get_branch_key(),
index,
}
} else {
let private_key = PrivateKey::random(&mut rand::thread_rng());
self.master_key_manager.import_key(private_key).await?
};
let public_key = self.master_key_manager.get_public_key_at_key_id(&key).await?;
(inputs!(public_key), key)
Expand Down Expand Up @@ -259,43 +266,4 @@ where

Ok(Some((key, committed_value, payment_id)))
}

/// Find the key manager index that corresponds to the spending key in the rewound output, if found then modify
/// output to contain correct associated script private key and update the key manager to the highest index it has
/// seen so far.
async fn update_outputs_script_private_key_and_update_key_manager_index(
&mut self,
output: &mut WalletOutput,
) -> Result<(), OutputManagerError> {
let public_key = self
.master_key_manager
.get_public_key_at_key_id(&output.spending_key_id)
.await?;
let script_key = {
let found_index = self
.master_key_manager
.find_key_index(
TransactionKeyManagerBranch::CommitmentMask.get_branch_key(),
&public_key,
)
.await?;

self.master_key_manager
.update_current_key_index_if_higher(
TransactionKeyManagerBranch::CommitmentMask.get_branch_key(),
found_index,
)
.await?;

TariKeyId::Derived {
branch: TransactionKeyManagerBranch::CommitmentMask.get_branch_key(),
label: TransactionKeyManagerLabel::ScriptKey.get_branch_key(),
index: found_index,
}
};
let public_script_key = self.master_key_manager.get_public_key_at_key_id(&script_key).await?;
output.input_data = inputs!(public_script_key);
output.script_key_id = script_key;
Ok(())
}
}
11 changes: 10 additions & 1 deletion base_layer/wallet/src/utxo_scanner_service/utxo_scanner_task.rs
Original file line number Diff line number Diff line change
Expand Up @@ -226,7 +226,6 @@ where
));

let timer = Instant::now();

loop {
let tip_header = self.get_chain_tip_header(&mut client).await?;
let tip_header_hash = tip_header.hash();
Expand Down Expand Up @@ -563,6 +562,7 @@ where
height: u64,
) -> Result<Vec<(WalletOutput, String, ImportStatus, TxId, TransactionOutput)>, UtxoScannerError> {
let mut found_outputs: Vec<(WalletOutput, String, ImportStatus, TxId, TransactionOutput)> = Vec::new();
let start = Instant::now();
found_outputs.append(
&mut self
.resources
Expand All @@ -586,6 +586,8 @@ where
})
.collect::<Result<Vec<_>, _>>()?,
);
let scanned_time = start.elapsed();
let start = Instant::now();

found_outputs.append(
&mut self
Expand Down Expand Up @@ -613,6 +615,13 @@ where
})
.collect::<Result<Vec<_>, _>>()?,
);
let one_sided_time = start.elapsed();
trace!(
target: LOG_TARGET,
"Scanned for outputs: outputs took {} ms , one-sided took {} ms",
scanned_time.as_millis(),
one_sided_time.as_millis(),
);
Ok(found_outputs)
}

Expand Down

0 comments on commit 43b2317

Please sign in to comment.