diff --git a/hopes/policy/action_probs.py b/hopes/policy/action_probs.py
index ade371e..d4d08fb 100644
--- a/hopes/policy/action_probs.py
+++ b/hopes/policy/action_probs.py
@@ -1,6 +1,6 @@
 import numpy as np
 
-from sab_ope.policy.policies import Policy
+from hopes.policy.policies import Policy
 
 
 def compute_action_probs_from_policy(policy: Policy, obs: np.ndarray) -> np.ndarray:
diff --git a/hopes/policy/policies.py b/hopes/policy/policies.py
index 1638293..0730402 100644
--- a/hopes/policy/policies.py
+++ b/hopes/policy/policies.py
@@ -1,8 +1,26 @@
 from abc import ABC, abstractmethod
 
+import numpy as np
+
 
 class Policy(ABC):
 
     @abstractmethod
-    def log_likelihoods(self, obs):
+    def log_likelihoods(self, obs: np.ndarray) -> np.ndarray:
         raise NotImplementedError
+
+
+class RandomPolicy(Policy):
+
+    def __init__(self, num_actions: int):
+        assert num_actions > 0, "Number of actions must be positive."
+        self.num_actions = num_actions
+
+    def log_likelihoods(self, obs: np.ndarray) -> np.ndarray:
+        assert obs.ndim == 2, "Observations must have shape (batch_size, obs_dim)."
+        assert obs.shape[1] > 0, "Observations must have shape (batch_size, obs_dim)."
+
+        action_probs = np.random.rand(obs.shape[0], self.num_actions)
+        action_probs /= action_probs.sum(axis=1, keepdims=True)
+        return np.log(action_probs)
+
diff --git a/tests/test_policies.py b/tests/test_policies.py
new file mode 100644
index 0000000..f2958d8
--- /dev/null
+++ b/tests/test_policies.py
@@ -0,0 +1,33 @@
+import unittest
+
+import numpy as np
+
+from hopes.policy.action_probs import compute_action_probs_from_policy
+from hopes.policy.policies import RandomPolicy
+
+
+class TestPolicies(unittest.TestCase):
+
+    def test_rnd_policy(self):
+        rnd_policy = RandomPolicy(num_actions=3)
+
+        log_probs = rnd_policy.log_likelihoods(obs=np.random.rand(10, 5))
+        self.assertIsInstance(log_probs, np.ndarray)
+        self.assertEqual(log_probs.shape, (10, 3))
+        self.assertTrue(np.all(log_probs <= 0.0))
+        self.assertTrue(np.all(log_probs >= -np.inf))
+
+        act_probs = np.exp(log_probs)
+        self.assertTrue(np.all(act_probs >= 0.0))
+        self.assertTrue(np.all(act_probs <= 1.0))
+        self.assertTrue(np.allclose(act_probs.sum(axis=1), 1.0))
+
+    def test_compute_action_probs(self):
+        rnd_policy = RandomPolicy(num_actions=3)
+        act_probs = compute_action_probs_from_policy(rnd_policy, obs=np.random.rand(10, 5))
+
+        self.assertIsInstance(act_probs, np.ndarray)
+        self.assertEqual(act_probs.shape, (10, 3))
+        self.assertTrue(np.all(act_probs >= 0.0))
+        self.assertTrue(np.all(act_probs <= 1.0))
+        self.assertTrue(np.allclose(act_probs.sum(axis=1), 1.0))
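
A minimal usage sketch of the pieces this diff touches (shapes mirror the tests; it is assumed, as the tests assert, that compute_action_probs_from_policy returns row-normalized action probabilities):

    import numpy as np

    from hopes.policy.action_probs import compute_action_probs_from_policy
    from hopes.policy.policies import RandomPolicy

    # Baseline policy that assigns random, row-normalized probabilities over 3 actions.
    policy = RandomPolicy(num_actions=3)

    # Batch of 10 observations with 5 features each (illustrative shapes only).
    obs = np.random.rand(10, 5)

    # Per-action log-likelihoods, shape (10, 3); each row exponentiates to a distribution.
    log_probs = policy.log_likelihoods(obs)

    # Action probabilities via the helper, shape (10, 3); rows sum to 1.
    act_probs = compute_action_probs_from_policy(policy, obs)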