Skip to content

Commit

Permalink
Allow testing using CPU #17
Browse files Browse the repository at this point in the history
  • Loading branch information
LeonardoEmili committed Mar 30, 2020
1 parent 9c82f16 commit 567c5ba
Showing 1 changed file with 16 additions and 19 deletions.
35 changes: 16 additions & 19 deletions src/wikification_nn.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -69,10 +69,10 @@
"metadata": {
"id": "mMAhu2EIlQ40",
"colab_type": "code",
"outputId": "12fa9346-1126-4ae4-efa4-51f704700c44",
"outputId": "433215b9-260e-4adf-d8e9-1161c66d8d24",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 156
"height": 68
}
},
"source": [
Expand All @@ -99,18 +99,14 @@
"from google.colab import drive\n",
"drive.mount('/content/drive')"
],
"execution_count": 2,
"execution_count": 1,
"outputs": [
{
"output_type": "stream",
"text": [
"[nltk_data] Downloading package punkt to /root/nltk_data...\n",
"[nltk_data] Package punkt is already up-to-date!\n",
"Go to this URL in a browser: https://accounts.google.com/o/oauth2/auth?client_id=947318989803-6bn6qk8qdgf4n4g3pfee6491hc0brc4i.apps.googleusercontent.com&redirect_uri=urn%3aietf%3awg%3aoauth%3a2.0%3aoob&response_type=code&scope=email%20https%3a%2f%2fwww.googleapis.com%2fauth%2fdocs.test%20https%3a%2f%2fwww.googleapis.com%2fauth%2fdrive%20https%3a%2f%2fwww.googleapis.com%2fauth%2fdrive.photos.readonly%20https%3a%2f%2fwww.googleapis.com%2fauth%2fpeopleapi.readonly\n",
"\n",
"Enter your authorization code:\n",
"··········\n",
"Mounted at /content/drive\n"
"Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount(\"/content/drive\", force_remount=True).\n"
],
"name": "stdout"
}
Expand Down Expand Up @@ -150,6 +146,7 @@
" val_slice = 0.20\n",
" inner_sep = '_'\n",
" outer_sep = '|'\n",
" # label entities using links that appear at least a certain number of times in the training set\n",
" link_cutoff = 50\n",
"\n",
" def __init__(self, dataset):\n",
Expand Down Expand Up @@ -708,7 +705,7 @@
"metadata": {
"id": "t9ZVZE-J7cDB",
"colab_type": "code",
"outputId": "a3e5cff7-ca16-4a17-cbf9-89dd0346d6d5",
"outputId": "457b41d1-147f-427d-929d-93136bacd0af",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 34
Expand Down Expand Up @@ -739,13 +736,13 @@
" device=\"cuda\"\n",
")\n",
"\n",
"if not torch.cuda.is_available():\n",
"if args.device == \"cpu\" or not torch.cuda.is_available():\n",
" print(\"Running model using CPU\")\n",
" args.device = 'cpu'\n",
"else:\n",
" print(\"Running model on GPU:\", torch.cuda.get_device_name(0), \"x\" + str(torch.cuda.device_count()))"
],
"execution_count": 9,
"execution_count": 7,
"outputs": [
{
"output_type": "stream",
Expand Down Expand Up @@ -908,7 +905,7 @@
"metadata": {
"id": "YHjabSqcDTMF",
"colab_type": "code",
"outputId": "825ff553-9a0d-4860-fc65-cfecf860106f",
"outputId": "f9dab219-f3bb-43f7-ad81-12cbbba2215f",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 782
Expand All @@ -919,7 +916,7 @@
"if args.save_model:\n",
" save_model(model, train_state, vocabulary, args)"
],
"execution_count": 12,
"execution_count": 9,
"outputs": [
{
"output_type": "display_data",
Expand Down Expand Up @@ -963,8 +960,8 @@
"colab": {}
},
"source": [
"# Load the model which was previously trained\n",
"state = torch.load(args.save_dir + args.model_state_file)\n",
"# Load the model which was previously trained on a GPU either using CPU or GPU\n",
"state = torch.load(args.save_dir + args.model_state_file, map_location=torch.device(args.device)) if args.device == \"cpu\" else torch.load(args.save_dir + args.model_state_file)\n",
"vocabulary = state['vocabulary']\n",
"model = BiLSTM(vocab_size=vocabulary.source_size(), embedding_dim=args.embedding_dim, hidden_dim=args.hidden_dim, target_size=vocabulary.target_size(), batch_size=args.inference_batch_size, device=args.device).to(args.device)\n",
"model.load_state_dict(state['state_dict'])\n",
Expand Down Expand Up @@ -997,11 +994,11 @@
"metadata": {
"id": "jDMKjba6qVeb",
"colab_type": "code",
"outputId": "cdc595e4-99c0-4a6b-d969-07b9d3221f8f",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 57
},
"outputId": "7f25aee9-b821-47b8-a967-e56904983a4c"
}
},
"source": [
"# Just an example sample sentence randomly picked from the test set\n",
Expand All @@ -1013,13 +1010,13 @@
"y = annotate(sample_sentence, stopwords, args)\n",
"display(HTML(y))"
],
"execution_count": 15,
"execution_count": 39,
"outputs": [
{
"output_type": "display_data",
"data": {
"text/html": [
"<h4>As a result , the <a href=\"https://en.wikipedia.org/wiki/United_kingdom\" target=\"_blank\">british</a> sphere of influence had effectively expanded from the <a href=\"https://en.wikipedia.org/wiki/Indian_subcontinent\" target=\"_blank\">indian subcontinent</a> into <a href=\"https://en.wikipedia.org/wiki/Tibet\" target=\"_blank\">tibet</a> , although <a href=\"https://en.wikipedia.org/wiki/Tibet\" target=\"_blank\">tibet</a> remained nominally under the sovereignty of the <a href=\"https://en.wikipedia.org/wiki/Qing_dynasty\" target=\"_blank\">qing dynasty</a> of <a href=\"https://en.wikipedia.org/wiki/China\" target=\"_blank\">china</a> .</h4>"
"<h4>A <a href=\"https://en.wikipedia.org/wiki/Huguenot\" target=\"_blank\">huguenot</a> and <a href=\"https://en.wikipedia.org/wiki/Officer\" target=\"_blank\">officer</a> under <a href=\"https://en.wikipedia.org/wiki/Admiral\" target=\"_blank\">admiral</a> gaspard de coligny , ribault led an expedition to the <a href=\"https://en.wikipedia.org/wiki/New_world\" target=\"_blank\">new world</a> in 1562 that founded the outpost of charlesfort on parris <a href=\"https://en.wikipedia.org/wiki/Island\" target=\"_blank\">island</a> in present-day <a href=\"https://en.wikipedia.org/wiki/South_carolina\" target=\"_blank\">south carolina</a> .</h4>"
],
"text/plain": [
"<IPython.core.display.HTML object>"
Expand Down

0 comments on commit 567c5ba

Please sign in to comment.