This repository has been archived by the owner on Jan 24, 2024. It is now read-only.

Add word2vec #12

Merged · 8 commits · Jan 1, 2017
3 changes: 3 additions & 0 deletions word2vec/.gitignore
@@ -0,0 +1,3 @@
data/train.list
data/test.list
data/simple-examples*
462 changes: 461 additions & 1 deletion word2vec/README.md

Large diffs are not rendered by default.

77 changes: 77 additions & 0 deletions word2vec/calculate_dis.py
@@ -0,0 +1,77 @@
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Example:
python calculate_dis.py DICTIONARYTXT FEATURETXT

Required arguments:
    DICTIONARYTXT  the dictionary generated by the dataprovider
    FEATURETXT     the word features in text format, one line per word
"""

import numpy as np
from argparse import ArgumentParser


def load_dict(fdict):
words = [line.strip() for line in fdict.readlines()]
dictionary = dict(zip(words, xrange(len(words))))
return dictionary


def load_emb(femb):
    feaBank = []
    flag_firstline = True
    for line in femb:
        # skip the one-line file head (version,floatSize,paraCount) written
        # by format_convert.py --b2t
        if flag_firstline:
            flag_firstline = False
            continue
        fea = np.array([float(x) for x in line.strip().split(',')])
        # L2-normalize each embedding so a plain dot product gives cosine similarity
        normfea = fea * 1.0 / np.linalg.norm(fea)
        feaBank.append(normfea)
    return feaBank


def calcos(id1, id2, Fea):
    # embeddings were pre-normalized in load_emb, so this is cosine similarity
    f1 = Fea[id1]
    f2 = Fea[id2]
    return np.dot(f1.transpose(), f2)


def get_wordidx(w, Dict):
if w not in Dict:
print 'ERROR: %s not in the dictionary' % w
return -1
return Dict[w]


if __name__ == '__main__':
parser = ArgumentParser()
parser.add_argument('dict', help='dictionary file')
parser.add_argument('fea', help='feature file')
args = parser.parse_args()

with open(args.dict) as fdict:
word_dict = load_dict(fdict)

with open(args.fea) as ffea:
word_fea = load_emb(ffea)

while True:
w1, w2 = raw_input("please input two words: ").split()
w1_id = get_wordidx(w1, word_dict)
w2_id = get_wordidx(w2, word_dict)
if w1_id == -1 or w2_id == -1:
continue
print 'similarity: %s' % (calcos(w1_id, w2_id, word_fea))
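
For reference, a minimal standalone check (not part of the diff) of what calcos computes: because load_emb L2-normalizes every row, a plain dot product equals cosine similarity. The toy vectors below are made up:

```python
import numpy as np

a = np.array([1.0, 2.0, 3.0])
b = np.array([2.0, 4.0, 6.0])
na = a / np.linalg.norm(a)
nb = b / np.linalg.norm(b)

# dot product of pre-normalized vectors, as calcos() does
print(np.dot(na, nb))  # 1.0 -- the vectors are parallel
# direct cosine similarity, same value
print(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b)))
```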
7 changes: 7 additions & 0 deletions word2vec/data/getdata.sh
@@ -0,0 +1,7 @@
#!/bin/bash
set -e

Collaborator: set -e

Collaborator: Also add .gitignore for getdata.sh generated files. (image attached)


wget http://www.fit.vutbr.cz/~imikolov/rnnlm/simple-examples.tgz
Collaborator: Put #!/bin/bash on the first line.

Collaborator (Author): Done

tar -zxf simple-examples.tgz
echo `pwd`/simple-examples/data/ptb.train.txt > train.list
echo `pwd`/simple-examples/data/ptb.valid.txt > test.list
63 changes: 63 additions & 0 deletions word2vec/dataprovider.py
@@ -0,0 +1,63 @@
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from paddle.trainer.PyDataProvider2 import *
import collections
import logging

logging.basicConfig(
format='[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s', )
logger = logging.getLogger('paddle')
logger.setLevel(logging.INFO)

N = 5  # N-gram window size
cutoff = 50  # keep only words with frequency > cutoff in the dictionary


def build_dict(ftrain, fdict):
    # collect every token (with <s>/<e> sentence markers) to count frequencies
    sentences = []
    with open(ftrain) as fin:
        for line in fin:
            line = ['<s>'] + line.strip().split() + ['<e>']
            sentences += line
    wordfreq = collections.Counter(sentences)
wordfreq = filter(lambda x: x[1] > cutoff, wordfreq.items())
dictionary = sorted(wordfreq, key=lambda x: (-x[1], x[0]))
words, _ = list(zip(*dictionary))
for word in words:
print >> fdict, word
word_idx = dict(zip(words, xrange(len(words))))
logger.info("Dictionary size=%s" % len(words))
return word_idx


def initializer(settings, srcText, dictfile, **xargs):
with open(dictfile, 'w') as fdict:
settings.dicts = build_dict(srcText, fdict)
input_types = []
for i in xrange(N):
input_types.append(integer_value(len(settings.dicts)))
settings.input_types = input_types


@provider(init_hook=initializer)
def process(settings, filename):
UNKID = settings.dicts['<unk>']
with open(filename) as fin:
for line in fin:
line = ['<s>'] * (N - 1) + line.strip().split() + ['<e>']
line = [settings.dicts.get(w, UNKID) for w in line]
for i in range(N, len(line) + 1):
yield line[i - N:i]
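
For intuition, a minimal standalone sketch (not part of the diff) of the 5-gram windowing that process yields; the toy dictionary below is made up:

```python
N = 5
dicts = {'<s>': 0, '<e>': 1, '<unk>': 2, 'the': 3, 'cat': 4, 'sat': 5}
UNKID = dicts['<unk>']

line = ['<s>'] * (N - 1) + 'the cat sat down'.split() + ['<e>']
ids = [dicts.get(w, UNKID) for w in line]

for i in range(N, len(ids) + 1):
    print(ids[i - N:i])
# [0, 0, 0, 0, 3]  -> four <s> paddings predict 'the'
# [0, 0, 0, 3, 4]  -> '<s> <s> <s> the' predicts 'cat'
# [0, 0, 3, 4, 5]
# [0, 3, 4, 5, 2]  -> 'down' is out of vocabulary, mapped to <unk>
# [3, 4, 5, 2, 1]  -> the sentence ends with <e>
```

In ngram.py the first four ids of each window feed firstw..fourthw and the last one is the fifthw label.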
158 changes: 158 additions & 0 deletions word2vec/format_convert.py
@@ -0,0 +1,158 @@
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Example:
python format_convert.py --b2t -i INPUT -o OUTPUT -d DIM
python format_convert.py --t2b -i INPUT -o OUTPUT

Options:
-h, --help show this help message and exit
--b2t convert parameter file of embedding model from binary to text
--t2b convert parameter file of embedding model from text to binary
-i INPUT input parameter file name
-o OUTPUT output parameter file name
-d DIM dimension of parameter
"""
from optparse import OptionParser
import struct


def binary2text(input, output, paraDim):
    """
    Convert a binary parameter file of the embedding model to a text file.
    input: the name of the input binary parameter file; the format is:
           1) the first 16 bytes are the file head:
              version (4 bytes): version of paddle, default = 0
              floatSize (4 bytes): sizeof(float) = 4
              paraCount (8 bytes): total number of parameters
           2) the next (paraCount * 4) bytes are the parameters, 4 bytes each
    output: the name of the output text parameter file, for example:
            0,4,32156096
            -0.7845433,1.1937413,-0.1704215,...
            0.0000909,0.0009465,-0.0008813,...
            ...
            the format is:
            1) the first line is the file head:
               version=0, floatSize=4, paraCount=32156096
            2) the remaining lines print the parameters
               a) each line prints paraDim parameters separated by ','
               b) there are paraCount/paraDim lines (one per embedded word)
    paraDim: dimension of the parameters
    """
    fi = open(input, "rb")
    fo = open(output, "w")

    version, floatSize, paraCount = struct.unpack("iil", fi.read(16))
newHead = ','.join([str(version), str(floatSize), str(paraCount)])
print >> fo, newHead

bytes = 4 * int(paraDim)
format = "%df" % int(paraDim)
context = fi.read(bytes)
line = 0

while context:
numbers = struct.unpack(format, context)
lst = []
for i in numbers:
lst.append('%8.7f' % i)
print >> fo, ','.join(lst)
context = fi.read(bytes)
line += 1
fi.close()
fo.close()
print "binary2text finish, total", line, "lines"


def get_para_count(input):
"""
Compute the total number of embedding parameters in input text file.
input: the name of input text file
"""
numRows = 1
paraDim = 0
with open(input) as f:
line = f.readline()
paraDim = len(line.split(","))
for line in f:
numRows += 1
return numRows * paraDim


def text2binary(input, output, paddle_head=True):
    """
    Convert a text parameter file of the embedding model to a binary file.
    input: the name of the input text parameter file, for example:
           -0.7845433,1.1937413,-0.1704215,...
           0.0000909,0.0009465,-0.0008813,...
           ...
           the format is:
           1) it has no file head
           2) each line stores the same number of parameters,
              separated by commas ','
    output: the name of the output binary parameter file; the format is:
            1) the first 16 bytes are the file head:
               version (4 bytes), floatSize (4 bytes), paraCount (8 bytes)
            2) the next (paraCount * 4) bytes are the parameters, 4 bytes each
    """
fi = open(input, "r")
fo = open(output, "wb")

newHead = struct.pack("iil", 0, 4, get_para_count(input))
fo.write(newHead)

count = 0
for line in fi:
line = line.strip().split(",")
for i in range(0, len(line)):
binary_data = struct.pack("f", float(line[i]))
fo.write(binary_data)
count += 1
fi.close()
fo.close()
print "text2binary finish, total", count, "lines"


def main():
"""
Main entry for running format_convert.py
"""
usage = "usage: \n" \
"python %prog --b2t -i INPUT -o OUTPUT -d DIM \n" \
"python %prog --t2b -i INPUT -o OUTPUT"
parser = OptionParser(usage)
parser.add_option(
"--b2t",
action="store_true",
help="convert parameter file of embedding model from binary to text")
parser.add_option(
"--t2b",
action="store_true",
help="convert parameter file of embedding model from text to binary")
parser.add_option(
"-i", action="store", dest="input", help="input parameter file name")
parser.add_option(
"-o", action="store", dest="output", help="output parameter file name")
parser.add_option(
"-d", action="store", dest="dim", help="dimension of parameter")
(options, args) = parser.parse_args()
if options.b2t:
binary2text(options.input, options.output, options.dim)
if options.t2b:
text2binary(options.input, options.output)


if __name__ == '__main__':
main()
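
As a sanity check (not part of the diff), a minimal round-trip of the 16-byte file head described in the docstrings above; the file name is hypothetical, and "iil" packs to 16 bytes with the native 64-bit Linux layout these scripts assume:

```python
import struct

# toy parameter matrix: paraCount=6 values, paraDim=3 per row
head = struct.pack("iil", 0, 4, 6)  # version=0, floatSize=4, paraCount=6
body = struct.pack("6f", 0.1, 0.2, 0.3, 0.4, 0.5, 0.6)

with open("toy_param.bin", "wb") as f:  # hypothetical file name
    f.write(head + body)

with open("toy_param.bin", "rb") as f:
    version, floatSize, paraCount = struct.unpack("iil", f.read(16))
    first_row = struct.unpack("3f", f.read(4 * 3))

print("version=%d floatSize=%d paraCount=%d" % (version, floatSize, paraCount))
print("first row: %s" % (first_row, ))
```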
Binary file added word2vec/image/2d_similarity.png
Binary file added word2vec/image/cbow.png
Binary file added word2vec/image/ngram.png
Binary file added word2vec/image/nnlm.png
Binary file added word2vec/image/sentence_emb.png
Binary file added word2vec/image/skipgram.png
84 changes: 84 additions & 0 deletions word2vec/ngram.py
@@ -0,0 +1,84 @@
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from paddle.trainer_config_helpers import *

import math

#################### Data Configuration ####################
args = {
'srcText': 'data/simple-examples/data/ptb.train.txt',
'dictfile': 'data/vocabulary.txt'
}
define_py_data_sources2(
train_list="data/train.list",
test_list="data/test.list",
module="dataprovider",
obj="process",
args=args)

settings(
batch_size=100, regularization=L2Regularization(8e-4), learning_rate=3e-3)

dictsize = 1953  # must match the dictionary size built by dataprovider.py (cutoff=50)
embsize = 32
hiddensize = 256

firstword = data_layer(name="firstw", size=dictsize)
secondword = data_layer(name="secondw", size=dictsize)
thirdword = data_layer(name="thirdw", size=dictsize)
fourthword = data_layer(name="fourthw", size=dictsize)
nextword = data_layer(name="fifthw", size=dictsize)


# construct the word embedding for each data layer; all four projections share
# a single embedding table because they use the same parameter name "_proj"
def wordemb(inlayer):
wordemb = table_projection(
input=inlayer,
size=embsize,
param_attr=ParamAttr(
name="_proj",
initial_std=0.001,
learning_rate=1,
l2_rate=0, ))
return wordemb


Efirst = wordemb(firstword)
Esecond = wordemb(secondword)
Ethird = wordemb(thirdword)
Efourth = wordemb(fourthword)

# concatenate the N-gram embeddings into a context embedding
contextemb = concat_layer(input=[Efirst, Esecond, Ethird, Efourth])
hidden1 = fc_layer(
input=contextemb,
size=hiddensize,
act=SigmoidActivation(),
layer_attr=ExtraAttr(drop_rate=0.5),
bias_attr=ParamAttr(learning_rate=2),
param_attr=ParamAttr(
initial_std=1. / math.sqrt(embsize * 8), learning_rate=1))

# use the context embedding to predict the next word
predictword = fc_layer(
input=hidden1,
size=dictsize,
bias_attr=ParamAttr(learning_rate=2),
act=SoftmaxActivation())

cost = classification_cost(input=predictword, label=nextword)

# network input and output
outputs(cost)
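
For a quick shape check (not part of the diff), the sizes implied by this config:

```python
dictsize, embsize, N = 1953, 32, 5

context_dim = (N - 1) * embsize  # four embeddings concatenated -> 128
emb_params = dictsize * embsize  # one shared "_proj" table -> 62496 floats

print("context dim = %d" % context_dim)
print("embedding parameters = %d" % emb_params)
```

Converting the trained "_proj" parameter with format_convert.py --b2t and -d 32 should therefore produce 1953 parameter rows of 32 comma-separated floats (plus the one-line head), which is the FEATURETXT input calculate_dis.py expects.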