Skip to content

Commit

Permalink
add more policies
Browse files Browse the repository at this point in the history
  • Loading branch information
antoine-galataud committed Mar 27, 2024
1 parent 54c1d81 commit ab6e18c
Show file tree
Hide file tree
Showing 2 changed files with 57 additions and 2 deletions.
35 changes: 34 additions & 1 deletion hopes/policy/policies.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
from abc import ABC, abstractmethod

import numpy as np
from sklearn.linear_model import LogisticRegression

from hopes.dev_utils import override


class Policy(ABC):
Expand Down Expand Up @@ -35,12 +38,41 @@ def __init__(self, num_actions: int):
assert num_actions > 0, "Number of actions must be positive."
self.num_actions = num_actions

@override(Policy)
def log_likelihoods(self, obs: np.ndarray) -> np.ndarray:
action_probs = np.random.rand(obs.shape[0], self.num_actions)
action_probs /= action_probs.sum(axis=1, keepdims=True)
return np.log(action_probs)


class RegressionBasedPolicy(Policy):
"""A policy that uses a regression model to predict the log-likelihoods of actions given
observations."""

def __init__(
self, obs: np.ndarray, act: np.ndarray, regression_model: str = "logistic"
) -> None:
"""
:param obs: the observations for training the regression model, shape: (batch_size, obs_dim).
:param act: the actions for training the regression model, shape: (batch_size,).
:param regression_model: the type of regression model to use. For now, only logistic is supported.
"""
assert regression_model in ["logistic"], "Only logistic regression is supported for now."
assert obs.ndim == 2, "Observations must have shape (batch_size, obs_dim)."
assert obs.shape[0] == act.shape[0], "Number of observations and actions must match."

self.model_x = obs
self.model_y = act
self.model = LogisticRegression()

def fit(self):
self.model.fit(self.model_x, self.model_y)

@override(Policy)
def log_likelihoods(self, obs: np.ndarray) -> np.ndarray:
return self.model.predict_log_proba(obs)


class HttpPolicy(Policy):
"""A policy that uses a remote HTTP server that returns log likelihoods for actions given
observations."""
Expand All @@ -57,7 +89,7 @@ def __init__(
ssl: bool = False,
port: int = 80,
verify_ssl: bool = True,
):
) -> None:
"""
:param host: the host of the HTTP server.
:param path: the path of the HTTP server.
Expand All @@ -77,5 +109,6 @@ def __init__(
self.ssl = ssl
self.verify_ssl = verify_ssl

@override(Policy)
def log_likelihoods(self, obs: np.ndarray) -> np.ndarray:
raise NotImplementedError
24 changes: 23 additions & 1 deletion tests/test_policies.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import numpy as np

from hopes.policy.policies import RandomPolicy
from hopes.policy.policies import RandomPolicy, RegressionBasedPolicy


class TestPolicies(unittest.TestCase):
Expand All @@ -20,6 +20,28 @@ def test_rnd_policy(self):
self.assertTrue(np.all(act_probs <= 1.0))
self.assertTrue(np.allclose(act_probs.sum(axis=1), 1.0))

def test_regression_policy(self):
# generate a random dataset of (obs, act) for target policy
num_actions = 3
num_obs = 5
num_samples = 100
obs = np.random.rand(num_samples, num_obs)
act = np.random.randint(num_actions, size=num_samples)

# create and fit a regression-based policy
reg_policy = RegressionBasedPolicy(obs=obs, act=act, regression_model="logistic")
reg_policy.fit()

# check if the policy returns valid log-likelihoods
new_obs = np.random.rand(10, num_obs)
act_probs = reg_policy.compute_action_probs(obs=new_obs)

self.assertIsInstance(act_probs, np.ndarray)
self.assertEqual(act_probs.shape, (10, 3))
self.assertTrue(np.all(act_probs >= 0.0))
self.assertTrue(np.all(act_probs <= 1.0))
self.assertTrue(np.allclose(act_probs.sum(axis=1), 1.0))

def test_compute_action_probs(self):
rnd_policy = RandomPolicy(num_actions=3)
act_probs = rnd_policy.compute_action_probs(obs=np.random.rand(10, 5))
Expand Down

0 comments on commit ab6e18c

Please sign in to comment.