Skip to content

Commit

Permalink
simplify dataloader
Browse files Browse the repository at this point in the history
  • Loading branch information
alon-albalak committed Mar 20, 2022
1 parent 2acbbee commit f2989fc
Show file tree
Hide file tree
Showing 4 changed files with 9 additions and 15 deletions.
14 changes: 7 additions & 7 deletions TLiDB/data_loaders/data_loaders.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import numpy as np
import random

def get_train_loader(dataset, batch_size, config, **loader_kwargs):
def get_train_loader(dataset, batch_size, model_type, **loader_kwargs):
"""
Constructs and return the data loader for training
Args:
Expand All @@ -12,15 +12,15 @@ def get_train_loader(dataset, batch_size, config, **loader_kwargs):
Returns:
- data_loader (DataLoader): The data loader for training
"""
if dataset.task_metadata['type'] == "multiple_choice" and config.model_type == "Encoder":
if dataset.task_metadata['type'] == "multiple_choice" and model_type == "Encoder":
# Encoder-only models split multiple choice into num_choices samples
# so we need to downscale the batch_size accordingly
batch_size = batch_size // dataset.task_metadata['num_choices']
if loader_kwargs['num_workers'] > 1:
loader_kwargs['pin_memory'] = True
return DataLoader(dataset, batch_size=batch_size, shuffle=True, **loader_kwargs)

def get_eval_loader(dataset, batch_size, config, **loader_kwargs):
def get_eval_loader(dataset, batch_size, model_type, **loader_kwargs):
"""
Constructs and return the data loader for evaluation
Args:
Expand All @@ -30,19 +30,19 @@ def get_eval_loader(dataset, batch_size, config, **loader_kwargs):
Returns:
- data_loader (DataLoader): The data loader for evaluation
"""
if dataset.task_metadata['type'] == "multiple_choice" and config.model_type == "Encoder":
if dataset.task_metadata['type'] == "multiple_choice" and model_type == "Encoder":
# Encoder-only models split multiple choice into num_choices samples
# so we need to downscale the batch_size accordingly
batch_size = batch_size // dataset.task_metadata['num_choices']
if loader_kwargs['num_workers'] > 1:
loader_kwargs['pin_memory'] = True
return DataLoader(dataset, batch_size=batch_size, shuffle=False, **loader_kwargs)

def get_dataloader(split, dataset, batch_size, model_type, **loader_kwargs):
    """Construct the data loader appropriate for the given split.

    Args:
        - split (str): 'train' selects the training loader (shuffled);
          any other value selects the evaluation loader (unshuffled)
        - dataset: the dataset to wrap in a DataLoader
        - batch_size (int): number of samples per batch
        - model_type (str): model architecture type (e.g. "Encoder")
        - loader_kwargs: extra keyword arguments forwarded to the loader
          (e.g. collate_fn, num_workers)
    Returns:
        - data_loader (DataLoader): the data loader for the requested split
    """
    if split == 'train':
        return get_train_loader(dataset, batch_size, model_type, **loader_kwargs)
    # Every non-train split (dev/test) uses the evaluation loader.
    return get_eval_loader(dataset, batch_size, model_type, **loader_kwargs)

class TLiDB_DataLoader:
"""
Expand Down
4 changes: 0 additions & 4 deletions examples/models/t5.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,6 @@ def transform_inputs(self, inputs):
"""Only tokenizes inputs"""
tokenized_inputs = self.tokenizer(inputs, padding="longest", pad_to_multiple_of=8,
truncation=True, return_tensors="pt")
#FIXME check if inputs['attention_mask'] should be added?
return tokenized_inputs

def transform_outputs(self, outputs):
Expand All @@ -47,16 +46,13 @@ def transform_outputs(self, outputs):
# replace pad tokens by -100
label_ids = tokenized_outputs.input_ids
label_ids[label_ids == self.tokenizer.pad_token_id] = -100
#FIXME check if inputs['labels'] should replace pad to -100?
return label_ids

def generate(self, input_ids, **kwargs):
    """Run the wrapped model's generation on ``input_ids``.

    All extra keyword arguments are forwarded unchanged to
    ``self.model.generate``; returns the generated prediction token ids.
    """
    return self.model.generate(input_ids=input_ids, **kwargs)

def batch_decode(self, tokens):
    """Decode a batch of token-id sequences into strings.

    Special tokens (padding, EOS, etc.) are dropped from the decoded text.
    """
    decoded = self.tokenizer.batch_decode(tokens, skip_special_tokens=True)
    return decoded

def initialize_model(config):
Expand Down
1 change: 0 additions & 1 deletion examples/results_analysis.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
from email.mime import base
import os
import csv
import sys
Expand Down
5 changes: 2 additions & 3 deletions examples/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,11 @@
import torch

from TLiDB.datasets.get_dataset import get_dataset
from TLiDB.data_loaders.data_loaders import get_loader
from TLiDB.data_loaders.data_loaders import get_dataloader
from TLiDB.metrics.initializer import get_metric_computer

def load_datasets_split(split, tasks, datasets, config):
split_datasets = {"datasets":[], "loaders":[], "metrics":[]}
get_data_loader = get_loader(split)
for t, d in zip(tasks, datasets):
cur_dataset = get_dataset(dataset=d,task=t,dataset_folder=config.data_dir,
model_type=config.model_type,
Expand All @@ -20,7 +19,7 @@ def load_datasets_split(split, tasks, datasets, config):
cur_dataset.random_subsample(config.frac)

split_datasets["datasets"].append(cur_dataset)
split_datasets["loaders"].append(get_data_loader(cur_dataset, config.gpu_batch_size, config, collate_fn=cur_dataset.collate, num_workers=config.num_workers))
split_datasets["loaders"].append(get_dataloader(split, cur_dataset, config.gpu_batch_size, config, collate_fn=cur_dataset.collate, num_workers=config.num_workers))
split_datasets["metrics"].append(get_metric_computer(cur_dataset.metrics, **cur_dataset.metric_kwargs))
return split_datasets

Expand Down

0 comments on commit f2989fc

Please sign in to comment.