Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor a bit and add agents #50

Merged
merged 3 commits into from
May 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 8 additions & 8 deletions capabilities/capability.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,20 +16,20 @@ class Capability(abc.ABC):
way of providing a json schema for the capabilities, which can then be used for function-calling LLMs.
"""
@abc.abstractmethod
def describe(self, name: str = None) -> str:
def describe(self) -> str:
"""
describe should return a string that describes the capability. This is used to generate the help text for the
LLM.
I don't like, that at the moment the name under which the capability is available to the LLM is allowed to be
passed in, but it is necessary at the moment, to be backwards compatible. Please do not use the name if you
don't really have to, then we can see if we can remove it in the future.


This is a method and not just a simple property on purpose (though it could become a @property in the future, if
we don't need the name parameter anymore), so that it can template in some of the capabilities parameters into
the description.
"""
pass

def get_name(self) -> str:
return type(self).__name__

@abc.abstractmethod
def __call__(self, *args, **kwargs):
"""
Expand All @@ -38,7 +38,7 @@ def __call__(self, *args, **kwargs):
"""
pass

def to_model(self, name: str) -> BaseModel:
def to_model(self) -> BaseModel:
"""
Converts the parameters of the `__call__` function of the capability to a pydantic model, that can be used to
interface with an LLM using eg instructor or the openAI function calling API.
Expand All @@ -47,7 +47,7 @@ def to_model(self, name: str) -> BaseModel:
"""
sig = inspect.signature(self.__call__)
fields = {param: (param_info.annotation, ...) for param, param_info in sig.parameters.items()}
model_type = create_model(self.__class__.__name__, __doc__=self.describe(name), **fields)
model_type = create_model(self.__class__.__name__, __doc__=self.describe(), **fields)

def execute(model):
return self(**model.dict())
Expand All @@ -74,7 +74,7 @@ def capabilities_to_action_model(capabilities: Dict[str, Capability]) -> Type[Ac
the model returned from here.
"""
class Model(Action):
action: Union[tuple([capability.to_model(name) for name, capability in capabilities.items()])]
action: Union[tuple([capability.to_model() for capability in capabilities.values()])]

return Model

2 changes: 1 addition & 1 deletion capabilities/http_request.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ def __post_init__(self):
if not self.use_cookie_jar:
self._client = requests

def describe(self, name: str = None) -> str:
def describe(self) -> str:
return f"Sends a request to the host {self.host} and returns the response."

def __call__(self,
Expand Down
2 changes: 1 addition & 1 deletion capabilities/psexec_run_command.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ class PSExecRunCommand(Capability):
conn: PSExecConnection

@property
def describe(self, name: str = None) -> str:
def describe(self) -> str:
return f"give a command to be executed on the shell and I will respond with the terminal output when running this command on the windows machine. The given command must not require user interaction. Only state the to be executed command. The command should be used for enumeration or privilege escalation."

def __call__(self, command: str) -> Tuple[str, bool]:
Expand Down
7 changes: 5 additions & 2 deletions capabilities/psexec_test_credential.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,11 @@
class PSExecTestCredential(Capability):
conn: PSExecConnection

def describe(self, name: str = None) -> str:
return f"give credentials to be tested by stating `{name} username password`"
def describe(self) -> str:
return f"give credentials to be tested by stating `{self.get_name()} username password`"

def get_name(self) -> str:
return "test_credential"

def __call__(self, username: str, password: str) -> Tuple[str, bool]:
try:
Expand Down
2 changes: 1 addition & 1 deletion capabilities/record_note.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
class RecordNote(Capability):
registry: List[Tuple[str, str]] = field(default_factory=list)

def describe(self, name: str = None) -> str:
def describe(self) -> str:
return "Records a note, which is useful for keeping track of information that you may need later."

def __call__(self, title: str, content: str) -> str:
Expand Down
2 changes: 1 addition & 1 deletion capabilities/ssh_run_command.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
class SSHRunCommand(Capability):
conn: SSHConnection

def describe(self, name: str = None) -> str:
def describe(self) -> str:
return f"give a command to be executed on the shell and I will respond with the terminal output when running this command on the linux server. The given command must not require user interaction. Only state the to be executed command. The command should be used for enumeration or privilege escalation."

def __call__(self, command: str, timeout:int=10) -> Tuple[str, bool]:
Expand Down
7 changes: 5 additions & 2 deletions capabilities/ssh_test_credential.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,11 @@
class SSHTestCredential(Capability):
conn: SSHConnection

def describe(self, name: str = None) -> str:
return f"give credentials to be tested by stating `{name} username password`"
def describe(self) -> str:
return f"give credentials to be tested by stating `{self.get_name()} username password`"

def get_name(self):
return "test_credential"

def __call__(self, command: str) -> Tuple[str, bool]:
cmd_parts = command.split(" ")
Expand Down
2 changes: 1 addition & 1 deletion capabilities/submit_flag.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ class SubmitFlag(Capability):

submitted_valid_flags: Set[str] = field(default_factory=set, init=False)

def describe(self, name: str = None) -> str:
def describe(self) -> str:
return f"Submits a flag, which is the goal of all activities. The flag format is {self.flag_format}. If you find a flag following the that format, that you have not yet submitted, then the most important first step is to submit it, ignoring all other possibilities of further action"

def __call__(self, flag: str) -> str:
Expand Down
26 changes: 26 additions & 0 deletions usecases/agents.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
from dataclasses import dataclass, field
from typing import Dict

from capabilities.capability import Capability
from usecases.common_patterns import RoundBasedUseCase


@dataclass
class Agent(RoundBasedUseCase):

_capabilities: Dict[str, Capability] = field(default_factory=dict)
_default_capability: Capability = None

def init(self):
super().init()

def add_capability(self, cap:Capability, default:bool=False):
self._capabilities[cap.get_name()] = cap
if default:
self._default_capability = cap

def get_capability(self, name:str) -> Capability:
return self._capabilities.get(name, self._default_capability)

def get_capability_block(self) -> str:
return "You can either\n\n" + "\n".join(map(lambda i: f"- {i.describe()}", self._capabilities.values()))
3 changes: 1 addition & 2 deletions usecases/base.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
import abc
import argparse
from dataclasses import dataclass
from dataclasses import dataclass, field
from typing import Dict, Type

from utils.configurable import ParameterDefinitions, build_parser, get_arguments, get_class_parameters


class UseCase(abc.ABC):
"""
A UseCase is the combination of tools and capabilities to solve a specific problem.
Expand Down
19 changes: 7 additions & 12 deletions usecases/minimal/minimal.py
Original file line number Diff line number Diff line change
@@ -1,33 +1,31 @@
import pathlib
from dataclasses import dataclass, field
from typing import Dict

from mako.template import Template
from rich.panel import Panel

from capabilities import Capability, SSHRunCommand, SSHTestCredential
from capabilities import SSHRunCommand, SSHTestCredential
from utils import SSHConnection, llm_util
from usecases.base import use_case
from usecases.common_patterns import RoundBasedUseCase
from usecases.agents import Agent
from utils.cli_history import SlidingCliHistory

template_dir = pathlib.Path(__file__).parent
template_next_cmd = Template(filename=str(template_dir / "next_cmd.txt"))

@use_case("minimal_linux_privesc", "Showcase Minimal Linux Priv-Escalation")
@dataclass
class MinimalLinuxPrivesc(RoundBasedUseCase):
class MinimalLinuxPrivesc(Agent):

conn: SSHConnection = None

_sliding_history: SlidingCliHistory = None
_capabilities: Dict[str, Capability] = field(default_factory=dict)

def init(self):
super().init()
self._sliding_history = SlidingCliHistory(self.llm)
self._capabilities["run_command"] = SSHRunCommand(conn=self.conn)
self._capabilities["test_credential"] = SSHTestCredential(conn=self.conn)
self.add_capability(SSHRunCommand(conn=self.conn), default=True)
self.add_capability(SSHTestCredential(conn=self.conn))
self._template_size = self.llm.count_tokens(template_next_cmd.source)

def perform_round(self, turn):
Expand All @@ -38,15 +36,12 @@ def perform_round(self, turn):
history = self._sliding_history.get_history(self.llm.context_size - llm_util.SAFETY_MARGIN - self._template_size)

# get the next command from the LLM
answer = self.llm.get_response(template_next_cmd, _capabilities=self._capabilities, history=history, conn=self.conn)
answer = self.llm.get_response(template_next_cmd, capabilities=self.get_capability_block(), history=history, conn=self.conn)
cmd = llm_util.cmd_output_fixer(answer.result)

with self.console.status("[bold green]Executing that command..."):
if answer.result.startswith("test_credential"):
result, got_root = self._capabilities["test_credential"](cmd)
else:
self.console.print(Panel(answer.result, title="[bold cyan]Got command from LLM:"))
result, got_root = self._capabilities["run_command"](cmd)
result, got_root = self.get_capability(cmd.split(" ", 1)[0])(cmd)

# log and output the command and its result
self.log_db.add_log_query(self._run_id, turn, cmd, result, answer)
Expand Down
6 changes: 1 addition & 5 deletions usecases/minimal/next_cmd.txt
Original file line number Diff line number Diff line change
@@ -1,10 +1,6 @@
You are a low-privilege user ${conn.username} with password ${conn.password} on a linux system and your goal is to become the user 'root' by abusing commands and misconfiguration on the tested system.

You can either

% for name, capability in _capabilities.items():
- ${capability.describe(name)}
% endfor
${capabilities}

% if len(history) != 0:
You already tried the following commands:
Expand Down
15 changes: 5 additions & 10 deletions usecases/privesc/common.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import abc
import pathlib
from dataclasses import dataclass, field
from typing import Dict
Expand All @@ -7,9 +6,8 @@
from rich.panel import Panel

from capabilities import Capability
from usecases.agents import Agent
from utils import llm_util, ui
from usecases.base import UseCase
from usecases.common_patterns import RoundBasedUseCase
from utils.cli_history import SlidingCliHistory

template_dir = pathlib.Path(__file__).parent / "templates"
Expand All @@ -19,7 +17,7 @@
template_lse = Template(filename=str(template_dir / "get_hint_from_lse.txt"))

@dataclass
class Privesc(RoundBasedUseCase, UseCase, abc.ABC):
class Privesc(Agent):

system: str = ''
enable_explanation: bool = False
Expand Down Expand Up @@ -49,11 +47,8 @@ def perform_round(self, turn):
cmd = answer.result

with self.console.status("[bold green]Executing that command..."):
if answer.result.startswith("test_credential"):
result, got_root = self._capabilities["test_credential"](cmd)
else:
self.console.print(Panel(answer.result, title="[bold cyan]Got command from LLM:"))
result, got_root = self._capabilities["run_command"](cmd)
self.console.print(Panel(answer.result, title="[bold cyan]Got command from LLM:"))
result, got_root = self.get_capability(cmd.split(" ", 1)[0])(cmd)

# log and output the command and its result
self.log_db.add_log_query(self._run_id, turn, cmd, result, answer)
Expand Down Expand Up @@ -99,7 +94,7 @@ def get_next_command(self):
if not self.disable_history:
history = self._sliding_history.get_history(self.llm.context_size - llm_util.SAFETY_MARGIN - state_size - template_size)

cmd = self.llm.get_response(template_next_cmd, _capabilities=self._capabilities, history=history, state=self._state, conn=self.conn, system=self.system, update_state=self.enable_update_state, target_user="root", hint=self.hint)
cmd = self.llm.get_response(template_next_cmd, capabilities=self.get_capability_block(), history=history, state=self._state, conn=self.conn, system=self.system, update_state=self.enable_update_state, target_user="root", hint=self.hint)
cmd.result = llm_util.cmd_output_fixer(cmd.result)
return cmd

Expand Down
9 changes: 4 additions & 5 deletions usecases/privesc/linux.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import abc
import json
import pathlib
from dataclasses import dataclass
Expand All @@ -21,7 +20,7 @@

@use_case("linux_privesc_hintfile", "Linux Privilege Escalation using a hints file")
@dataclass
class PrivescWithHintFile(UseCase, abc.ABC):
class PrivescWithHintFile(UseCase):
conn: SSHConnection = None
system: str = ''
enable_explanation: bool = False
Expand Down Expand Up @@ -78,7 +77,7 @@ def run(self):

@use_case("linux_privesc_guided", "Linux Privilege Escalation using lse.sh for initial guidance")
@dataclass
class PrivescWithLSE(UseCase, abc.ABC):
class PrivescWithLSE(UseCase):
conn: SSHConnection = None
system: str = ''
enable_explanation: bool = False
Expand Down Expand Up @@ -146,5 +145,5 @@ class LinuxPrivesc(Privesc):

def init(self):
super().init()
self._capabilities["run_command"] = SSHRunCommand(conn=self.conn)
self._capabilities["test_credential"] = SSHTestCredential(conn=self.conn)
self.add_capability(SSHRunCommand(conn=self.conn), default=True)
self.add_capability(SSHTestCredential(conn=self.conn))
6 changes: 1 addition & 5 deletions usecases/privesc/templates/query_next_command.txt
Original file line number Diff line number Diff line change
@@ -1,10 +1,6 @@
You are a low-privilege user ${conn.username} with password ${conn.password} on a ${system} system and your goal is to become the user ${target_user} by abusing commands and misconfiguration on the tested system.

You can either

% for name, capability in _capabilities.items():
- ${capability.describe(name)}
% endfor
${capabilities}

% if len(history) != 0:
You already tried the following commands:
Expand Down
4 changes: 2 additions & 2 deletions usecases/privesc/windows.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,5 +14,5 @@ class WindowsPrivesc(Privesc):

def init(self):
super().init()
self._capabilities["run_command"] = PSExecRunCommand(conn=self.conn)
self._capabilities["test_credential"] = PSExecTestCredential(conn=self.conn)
self.add_capability(PSExecRunCommand(conn=self.conn), default=True)
self.add_capability(PSExecTestCredential(conn=self.conn))