Skip to content

Commit

Permalink
fix 2.0.1 crf bug (PaddlePaddle#464)
Browse files Browse the repository at this point in the history
* 1. Fix the 2.0.1 slice bug; 2. Cast the int32 labels returned on Windows to int64

* Change the dtype of input_ids and token_type_ids: int64 -> int32
  • Loading branch information
joey12300 committed May 31, 2021
1 parent 880b4e2 commit 4d3905b
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 11 deletions.
6 changes: 3 additions & 3 deletions examples/information_extraction/waybill_ie/run_bigru_crf.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,9 +87,9 @@ def predict(model, data_loader, ds, label_vocab):
test_ds.map(trans_func)

batchify_fn = lambda samples, fn=Tuple(
- Pad(axis=0, pad_val=word_vocab.get('OOV', 0)), # token_ids
- Stack(), # seq_len
- Pad(axis=0, pad_val=label_vocab.get('O', 0)) # label_ids
+ Pad(axis=0, pad_val=word_vocab.get('OOV', 0), dtype='int32'), # token_ids
+ Stack(dtype='int64'), # seq_len
+ Pad(axis=0, pad_val=label_vocab.get('O', 0), dtype='int64') # label_ids
): fn(samples)

train_loader = paddle.io.DataLoader(
Expand Down
8 changes: 4 additions & 4 deletions examples/information_extraction/waybill_ie/run_ernie.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,10 +83,10 @@ def predict(model, data_loader, ds, label_vocab):

ignore_label = -1
batchify_fn = lambda samples, fn=Tuple(
- Pad(axis=0, pad_val=tokenizer.pad_token_id), # input_ids
- Pad(axis=0, pad_val=tokenizer.pad_token_type_id), # token_type_ids
- Stack(), # seq_len
- Pad(axis=0, pad_val=ignore_label) # labels
+ Pad(axis=0, pad_val=tokenizer.pad_token_id, dtype='int32'), # input_ids
+ Pad(axis=0, pad_val=tokenizer.pad_token_type_id, dtype='int32'), # token_type_ids
+ Stack(dtype='int64'), # seq_len
+ Pad(axis=0, pad_val=ignore_label, dtype='int64') # labels
): fn(samples)

train_loader = paddle.io.DataLoader(
Expand Down
8 changes: 4 additions & 4 deletions examples/information_extraction/waybill_ie/run_ernie_crf.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,10 +81,10 @@ def predict(model, data_loader, ds, label_vocab):
test_ds.map(trans_func)

batchify_fn = lambda samples, fn=Tuple(
- Pad(axis=0, pad_val=tokenizer.pad_token_id), # input_ids
- Pad(axis=0, pad_val=tokenizer.pad_token_type_id), # token_type_ids
- Stack(), # seq_len
- Pad(axis=0, pad_val=label_vocab.get("O", 0)) # labels
+ Pad(axis=0, pad_val=tokenizer.pad_token_id, dtype='int32'), # input_ids
+ Pad(axis=0, pad_val=tokenizer.pad_token_type_id, dtype='int32'), # token_type_ids
+ Stack(dtype='int64'), # seq_len
+ Pad(axis=0, pad_val=label_vocab.get("O", 0), dtype='int64') # labels
): fn(samples)

train_loader = paddle.io.DataLoader(
Expand Down

0 comments on commit 4d3905b

Please sign in to comment.