Skip to content

Commit

Permalink
work on policies
Browse files Browse the repository at this point in the history
  • Loading branch information
antoine-galataud committed Mar 27, 2024
1 parent 7d3ef64 commit 86f86b4
Show file tree
Hide file tree
Showing 3 changed files with 53 additions and 2 deletions.
2 changes: 1 addition & 1 deletion hopes/policy/action_probs.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import numpy as np

from sab_ope.policy.policies import Policy
from hopes.policy.policies import Policy


def compute_action_probs_from_policy(policy: Policy, obs: np.ndarray) -> np.ndarray:
Expand Down
20 changes: 19 additions & 1 deletion hopes/policy/policies.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,26 @@
from abc import ABC, abstractmethod

import numpy as np


class Policy(ABC):
    """Abstract interface for a policy that scores actions given observations.

    Concrete policies must override :meth:`log_likelihoods`; the class cannot
    be instantiated until that method is implemented.
    """

    @abstractmethod
    def log_likelihoods(self, obs: np.ndarray) -> np.ndarray:
        """Return the log-likelihoods of actions for a batch of observations.

        :param obs: the observations to evaluate. NOTE(review): the concrete
            implementation in this file expects shape (batch_size, obs_dim) --
            confirm whether that is the contract for every policy.
        :return: the log-likelihoods as an array. NOTE(review): presumably
            shape (batch_size, num_actions); verify against callers.
        :raises NotImplementedError: if a subclass delegates to this base
            implementation instead of providing its own.
        """
        raise NotImplementedError


class RandomPolicy(Policy):
    """Policy that assigns freshly sampled random probabilities to each action.

    The probabilities do not depend on the observation values: only the batch
    size (``obs.shape[0]``) is read from the input. A new distribution is drawn
    from NumPy's global random state on every call, so two calls with the same
    observations return different results.
    """

    def __init__(self, num_actions: int):
        """Store the size of the action space.

        :param num_actions: number of actions; must be strictly positive.
        """
        # NOTE: validation relies on ``assert`` and is skipped under ``python -O``.
        assert num_actions > 0, "Number of actions must be positive."
        self.num_actions = num_actions

    def log_likelihoods(self, obs: np.ndarray) -> np.ndarray:
        """Return log-probabilities of a random distribution per observation.

        :param obs: observations of shape (batch_size, obs_dim), with
            obs_dim > 0.
        :return: array of shape (batch_size, num_actions) whose rows are the
            logs of probabilities that sum to one.
        """
        assert obs.ndim == 2, "Observations must have shape (batch_size, obs_dim)."
        assert obs.shape[1] > 0, "Observations must have shape (batch_size, obs_dim)."

        batch_size = obs.shape[0]
        # Draw unnormalised weights in [0, 1), one row per observation.
        raw_weights = np.random.rand(batch_size, self.num_actions)
        # Normalise each row so that it forms a probability distribution.
        row_totals = raw_weights.sum(axis=1, keepdims=True)
        return np.log(raw_weights / row_totals)

33 changes: 33 additions & 0 deletions tests/test_policies.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
import unittest

import numpy as np

from hopes.policy.action_probs import compute_action_probs_from_policy
from hopes.policy.policies import RandomPolicy


class TestPolicies(unittest.TestCase):
    """Unit tests for ``RandomPolicy`` and ``compute_action_probs_from_policy``."""

    def _check_action_probs(self, probs):
        """Assert every entry lies in [0, 1] and every row sums to one."""
        self.assertTrue(np.all(probs >= 0.0))
        self.assertTrue(np.all(probs <= 1.0))
        self.assertTrue(np.allclose(probs.sum(axis=1), 1.0))

    def test_rnd_policy(self):
        """Log-likelihoods have the expected type, shape and value range."""
        policy = RandomPolicy(num_actions=3)
        observations = np.random.rand(10, 5)

        log_probs = policy.log_likelihoods(obs=observations)
        self.assertIsInstance(log_probs, np.ndarray)
        self.assertEqual(log_probs.shape, (10, 3))
        self.assertTrue(np.all(log_probs <= 0.0))
        # This comparison is False for NaN entries, so it rejects NaNs.
        self.assertTrue(np.all(log_probs >= -np.inf))

        # Exponentiating must recover a valid probability distribution.
        self._check_action_probs(np.exp(log_probs))

    def test_compute_action_probs(self):
        """Action probabilities computed from a policy form valid distributions."""
        policy = RandomPolicy(num_actions=3)
        observations = np.random.rand(10, 5)
        act_probs = compute_action_probs_from_policy(policy, obs=observations)

        self.assertIsInstance(act_probs, np.ndarray)
        self.assertEqual(act_probs.shape, (10, 3))
        self._check_action_probs(act_probs)

0 comments on commit 86f86b4

Please sign in to comment.