diff --git a/handyrl/train.py b/handyrl/train.py
index dd0b257c..a594a485 100755
--- a/handyrl/train.py
+++ b/handyrl/train.py
@@ -24,7 +24,7 @@
 
 from .environment import prepare_env, make_env
 from .util import map_r, bimap_r, trimap_r, rotate
-from .model import to_torch, to_gpu, ModelWrapper
+from .model import to_torch, to_gpu
 from .losses import compute_target
 from .connection import MultiProcessJobExecutor
 from .worker import WorkerCluster, WorkerServer
@@ -330,10 +330,9 @@ def __init__(self, args, model):
         self.update_flag = False
         self.update_queue = queue.Queue(maxsize=1)
 
-        self.wrapped_model = ModelWrapper(self.model)
-        self.trained_model = self.wrapped_model
+        self.trained_model = self.model
         if self.gpu > 1:
-            self.trained_model = nn.DataParallel(self.wrapped_model)
+            self.trained_model = nn.DataParallel(self.model)
 
     def update(self):
         self.update_flag = True
@@ -354,7 +353,9 @@ def train(self):
             batch = self.batcher.batch()
             batch_size = batch['value'].size(0)
             player_count = batch['value'].size(2)
-            hidden = self.wrapped_model.init_hidden([batch_size, player_count])
+            hidden = None
+            if hasattr(self.model, 'init_hidden'):
+                hidden = self.model.init_hidden([batch_size, player_count])
             if self.gpu > 0:
                 batch = to_gpu(batch)
                 hidden = to_gpu(hidden)
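
Not part of the patch: a minimal sketch of the duck-typing convention the new hasattr check relies on, under the assumption that models carrying hidden state expose an init_hidden([batch_size, player_count]) method while stateless models simply omit it. The class names RecurrentModel and FeedForwardModel and the helper init_hidden_like_trainer below are hypothetical illustrations, not code from the repository.

import torch
import torch.nn as nn


class RecurrentModel(nn.Module):
    """Hypothetical model with per-player hidden state; exposes init_hidden."""
    def __init__(self, input_size=8, hidden_size=16):
        super().__init__()
        self.cell = nn.GRUCell(input_size, hidden_size)
        self.hidden_size = hidden_size

    def init_hidden(self, batch_shape):
        # batch_shape is [batch_size, player_count] in the trainer
        return torch.zeros(*batch_shape, self.hidden_size)


class FeedForwardModel(nn.Module):
    """Hypothetical stateless model; it does not define init_hidden."""
    def __init__(self, input_size=8):
        super().__init__()
        self.fc = nn.Linear(input_size, 4)


def init_hidden_like_trainer(model, batch_size, player_count):
    # Mirrors the patched train() loop: hidden stays None for models
    # that do not define init_hidden.
    hidden = None
    if hasattr(model, 'init_hidden'):
        hidden = model.init_hidden([batch_size, player_count])
    return hidden


print(init_hidden_like_trainer(RecurrentModel(), 32, 2).shape)  # torch.Size([32, 2, 16])
print(init_hidden_like_trainer(FeedForwardModel(), 32, 2))      # None

Note that the check is made against self.model, the raw module, rather than the nn.DataParallel wrapper used when more than one GPU is available, so the convention holds regardless of how the trained model is wrapped.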