diff --git a/MANIFEST.in b/MANIFEST.in index 04e751d..5fd80fc 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -21,6 +21,6 @@ include scripts/wordhash_regressor.py include scripts/wordseq_regressor.py include scripts/wordvec_regressor.py include scripts/classify_airline_sentiment.py -include scripts/wordbag_regressor_spark.py +include scripts/backends_benchmark.py include data/Tweets.csv \ No newline at end of file diff --git a/README.rst b/README.rst index 0da3082..373835a 100644 --- a/README.rst +++ b/README.rst @@ -1,85 +1,155 @@ -Wordbatch -========= +============= +Wordbatch 1.4 +============= -Parallel text feature extraction for machine learning. +Overview +======== -Wordbatch produces parallel feature extraction from raw text data for uses such as deep learning and text analytics. The most basic use for Wordbatch is as a drop-in replacement for the more basic non-parallelized extraction available in toolkits such as Scikit-learn, TfLearn and Keras. Wordbatch additionally provides customizable preprocessing and feature extractors that improve predictive performance. +Python library for distributed AI processing pipelines, using swappable scheduler backends. -Unlike text preprocessing in these toolkits that either deal with text as a single batch or as a stream, Wordbatch works best with large minibatches of text data. Wordbatch internally stores per-batch statistics of the data, and applies these for uses such as dictionary selection, spelling correction, and online IDF weighting. The larger the batches, the better choices Wordbatch can make in extracting features. +Wordbatch parallelizes task pipelines as minibatches processed by a chosen scheduler backend. This allows +the user to develop AI programs on a local workstation or laptop, and scale the same +solution on a cluster or the cloud, simply by changing the pipeline backend to a distributed scheduler such as Spark, +Dask and Ray. A backend can be chosen based on performance characteristics on a particular task, and swapped for +different situations. For example, an AI model can be trained using a distributed backend, and then debugged or +deployed using a single serial process. -The current text preprocessing options include passing any function as text normalization to be parallelized, a constant-time adaptive version of Norvig spelling correction, and passing any function for parallel stemming. - -Currently four basic feature extractor classes are provided: - -- WordHash is simply the Scikit-learn HashingVectorizer wrapped with the Wordbatch parallelization, providing multiplied processing speeds -- WordBag is a flexible alternative to Wordhash, providing cababilities missing from Scikit-learn, such as IDF and per n-gram order weighting of hashed features, windowed and distance-weighted polynomial interactions, and more transforms for counts. -- WordSeq provides sequences of word integers, as used by the deep learning toolkits for input into LSTM models. -- WordVec provides embedding transforms from words into wordvectors - -A list of extractors can be defined. For example, word vector sequences can be projected into per-document vectors, and concatenated with the vectors from other word vector embeddings. - -Four basic OpenMP-parallelized L1&L2-regularized online learning models are provided, for single-label regression and classification: - -- FTRL : Linear model Proximal-FTRL that has become the most popular algorithm for online learning of linear models in Kaggle competions. 
The Cython-optimized implementation should be the fastest available version of FTRL. -- FM_FTRL : Factorization Machines. Linear effects estimated with FTRL and factor effects estimated with adaptive SGD. Prediction and estimation multithreaded across factors. -- NN_Relu_H1 : Neural Network with 1 hidden layer and Rectified Linear Unit activations, estimated with adaptive SGD. Prediction and estimation multithreaded across hidden layer. -- NN_Relu_H2: Neural Network with 2 hidden layers and Rectified Linear Unit activations, estimated with adaptive SGD. Prediction multithreaded across 2nd hidden layer, estimation across 1st hidden layer outputs. - -The adaptive SGD optimizer works like Adagrad, but pools the adaptive learning rates across hidden nodes using the same feature. This makes learning more robust and requires less memory. - -Wordbatch is written with Cython, and uses concurrent threading, multiprocessing and OpenMP parallelization for circumventing the Python GIL. License is GNU GPL 2.0, and less restrictive licenses are available on request. +The library is organized around the orchestrator class Batcher, and Sklearn-compatible components, +split into Pipelines, Transformers, Extractors and Models. These extend the Scikit-learn API with a +fit_partial()-method, that enables transformers and models to be used in a streaming fashion. +The current set of components has been developed mostly for text processing tasks, but components for other domains +can be developed based on the available classes. Requirements ============ -Linux/Windows. Python 2.7 / 3.6 / 3.7 +Linux / Windows / macOS. Python 3.6 / 3.7 Installation ============ pip install wordbatch +macOS: compile using GCC-7 (https://github.com/anttttti/Wordbatch/issues/1) + Getting started =============== -| #from sklearn.feature_extraction.text import HashingVectorizer -| #from sklearn.linear_model import * -| #vct= HashingVectorizer() -| #clf= SGDRegressor() -| -| import wordbatch | from wordbatch.models import FTRL | from wordbatch.extractors import WordBag -| wb= wordbatch.WordBatch(extractor=(WordBag, {"hash_ngrams":2, "hash_ngrams_weights":[0.5, -1.0], "hash_size":2**23, "norm":'l2', "tf":'log', "idf":50.0})) +| from wordbatch.pipelines import WordBatch +| from wordbatch.batcher import Batcher +| +| wb= WordBatch(extractor=WordBag(hash_ngrams=0, norm= 'l2', tf= 'binary', idf= 50.0), +| batcher=Batcher(backend="multiprocessing")) +| | clf= FTRL(alpha=1.0, beta=1.0, L1=0.00001, L2=1.0, D=2 ** 25, iters=1) | -| train_texts= ["Cut down a tree with a herring? It can't be done.", "Don't say that word.", "How can we not say the word if you don't tell us what it is?"] +| train_texts= ["Cut down a tree with a herring? It can't be done.", +| "Don't say that word.", +| "How can we not say the word if you don't tell us what it is?"] | train_labels= [1, 0, 1] | test_texts= ["Wait! I said it! I said it! Ooh! I said it again!"] | -| clf.fit(wb.transform(train_texts), train_labels) -| preds= clf.predict(wb.transform(test_texts)) +| clf.fit(wb.fit_transform(train_texts), train_labels) +| print(clf.predict(wb.transform(test_texts))) +| +| import ray +| ray.init() +| wb.batcher.backend= "ray" +| wb.batcher.backend_handle= ray +| +| clf.fit(wb.fit_transform(train_texts), train_labels) +| print(clf.predict(wb.transform(test_texts))) + + +Components +========== + +Batcher +------- +Batcher orchestrates MapReduce processing of tasks using a backend, by splitting input data into separately processed +minibatches. 
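+
+For illustration, a minimal sketch of calling Batcher directly with an arbitrary row-level task, relying only on the
+process_batches() API shown later in this changeset (the backend choice, worker count and minibatch size here are
+arbitrary placeholders):
+
+| from wordbatch.batcher import Batcher
+|
+| def token_counts(minibatch):
+|     # minibatch[0] holds one chunk of the input data; any extra args passed to process_batches follow it
+|     return [len(text.split()) for text in minibatch[0]]
+|
+| batcher= Batcher(procs=2, minibatch_size=2, backend="serial")
+| # per the process_batches() docstring, merge_output=True (the default) flattens the
+| # per-minibatch results back into one list; swap the backend string and backend_handle
+| # to use any of the schedulers listed below
+| print(batcher.process_batches(token_counts, ["a b c", "d e", "f", "g h"], []))
+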
+Currently three local (serial, multiprocessing, Loky) and three distributed backends (Spark, Dask,
+Ray) are supported. Some distributed backends will process the tasks concurrently as a graph of lazily evaluated
+futures, with Batcher dynamically sending the graph to the backend for processing. All three supported distributed
+backends allow real-time monitoring of the processing pipeline using the backend's own GUI.
+
+
+Pipelines
+---------
+Pipelines are classes that send functions, methods and classes to Batcher for processing. Unlike other components in
+Wordbatch, pipelines contain a reference to Batcher, and are never referenced themselves in the calls sent to Batcher.
+This prevents trying to serialize and send the backend handle itself. The simplest pipeline is Apply,
+which processes a function or method over the input data row-by-row. WordBatch is a full-featured pipeline for text
+processing, with optional steps such as text normalization, spelling correction, stemming, feature extraction, and
+LZ4-caching of results.
+
+
+Transformers
+------------
+Transformers are transformer classes extending the Scikit-learn API, accepting a Batcher instance as an argument
+to their fit and transform methods. Transformers won't store Batcher references, allowing the transformer objects to
+be sent to distributed workers. This allows transformers to do MapReduce operations as part of their methods, for
+example gathering a dictionary of words from data when fitting a Dictionary. The current set of transformers consists
+of text-specific classes, such as Dictionary, Tokenizer and TextNormalizer.
+
+
+Extractors
+----------
+Extractors are transformer classes which don't directly call Batcher. Since extractors can't call Batcher directly,
+they are mostly immutable and used through their transform() calls, which are distributed using a pipeline. The
+current set of extractors is Cython-optimized and, aside from PandasHash, intended for text feature extraction.
+These are:
+
+- WordHash is a wrapper for the Scikit-learn HashingVectorizer, extended with an option for LZ4-caching
+- WordBag is a flexible alternative to WordHash, with options such as IDF and per n-gram order weighting
+- WordSeq provides sequences of word integers, as used by deep learning language models
+- WordVec embeds words into word vector representations
+- PandasHash extracts hashed features from a Pandas DataFrame, similar to VowpalWabbit's feature extraction
+
+
+Models
+------
+Models are predictive models such as classifiers. Similar to extractors, they don't directly call Batcher, but are
+Scikit-learn compatible and distributed using a pipeline if needed. Currently four
+OpenMP-multithreaded L1&L2-regularized online learning models are provided, for single-label regression and
+classification:
+
+- FTRL : Linear model Proximal-FTRL that has become the most popular algorithm for online learning of linear models in Kaggle competitions. The Cython-optimized implementation should be the fastest available version of FTRL.
+- FM_FTRL : Factorization Machines. Linear effects estimated with FTRL and factor effects estimated with adaptive SGD. Prediction and estimation multithreaded across factors.
+- NN_Relu_H1 : Neural Network with 1 hidden layer and Rectified Linear Unit activations, estimated with adaptive SGD. Prediction and estimation multithreaded across hidden layer.
+- NN_Relu_H2: Neural Network with 2 hidden layers and Rectified Linear Unit activations, estimated with adaptive SGD.
Prediction multithreaded across 2nd hidden layer, estimation across 1st hidden layer outputs. +The adaptive SGD optimizer works like Adagrad, but pools the adaptive learning rates across hidden nodes using the same +feature. This makes learning more robust and requires less memory. FM_FTRL uses AVX2-optimization, so that processors +supporting AVX2 will run the factorization model up to four times faster. Example scripts =============== -The directory /scripts/ contains four scripts for demonstrating the basic extractors, and a Scikit-learn ensemble model to combine predictions. To run the scripts you should first install the dependencies: Keras, NLTK, TextBlob and Pandas. The scripts also use the TripAdvisor dataset (http://times.cs.uiuc.edu/~wang296/Data/) for training models, and the precomputed word embeddings glove.twitter.27B.100d and glove.6B.50d (http://nlp.stanford.edu/projects/glove/). The test data from Crowdflower Open data & Kaggle is provided in the /data directory. - -- wordhash_regressor.py shows wordbatch.extractors.WordHash, and feature extraction concurrent with file reading -- wordbag_regressor.py shows wordbatch.extractors.WordBag, and online feature extraction and parallel FTRL training -- wordseq_regressor.py shows wordbatch.extractors.WordSeq, and training a 1D-convnet regression model -- wordvec_regressor.py shows wordbatch.extractors.WordVec, and combining word vector embeddings for FTRL training -- classify_airline_sentiment.py show how to combine predictions from the four scripts using a Random Forest Regressor on the airline sentiment data +The directory /scripts/ contains scripts for demonstrating and testing basic uses of the toolkit. To run the scripts +one should first install the dependencies: Keras, NLTK, TextBlob, Pandas, Ray, Dask Distributed and PySpark. +The scripts also use the TripAdvisor dataset (http://times.cs.uiuc.edu/~wang296/Data/), and the +precomputed word embeddings glove.twitter.27B.100d and glove.6B.50d (http://nlp.stanford.edu/projects/glove/). Test +data from Crowdflower Open data & Kaggle is provided in the /data directory. -Spark integration -================= -Starting from 1.2, Wordbatch has full Spark integration. All processing steps will be parallelized by Spark, simply by setting wb.use_sc=True and providing data in the RDD format produced by wb.lists2rddbatches(texts). +Airline Classification Example +------------------------------ +classify_airline_sentiment.py shows training and combining predictions with four classifier scripts that use the +Wordbatch extractors and models: wordhash_regressor.py, wordbag_regressor.py, wordseq_regressor.py and +wordvec_regressor.py. The header part of the script can be modified to choose the backend. By default Ray is used and +passed to the other scripts. -A basic script using this is wordbag_regressor_spark.py, which is the wordbag_regressor.py script modified to run on Spark. This converts each minibatch of training data into an RDD, does feature extraction on the RDD, and collects the resulting features for local FTRL model training. A more practical script should read the data from parallelized storage, and implement model training on the RDD as well. +Backends Benchmark Example +-------------------------- +backends_benchmark.py shows how to benchmark different backends on two simple pipeline tasks: +using ApplyBatch with Scikit-learn HashingVectorizer, and running WordBatch Pipeline with most of its possible +processing steps. 
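+
+A condensed sketch of the benchmark's ApplyBatch task, adapted from scripts/backends_benchmark.py further down in this
+changeset (assumptions: a small in-memory text list replaces the TripAdvisor data, the normalize_text preprocessor is
+dropped, and only the local backends are looped over; the distributed backends can be added as in the full script):
+
+| from sklearn.feature_extraction.text import HashingVectorizer
+| from wordbatch.pipelines import ApplyBatch
+| from wordbatch.batcher import Batcher
+|
+| texts= ["an example hotel review"] * 10000
+| hv= HashingVectorizer(decode_error='ignore', n_features=2 ** 25, ngram_range=(1, 2), norm='l2')
+| for backend, handle in [("serial", ""), ("multiprocessing", "")]:
+|     batcher= Batcher(procs=16, minibatch_size=5000, backend=backend, backend_handle=handle)
+|     # hv.transform runs on each minibatch; the merged result is a single sparse matrix
+|     X= ApplyBatch(hv.transform, batcher=batcher).transform(texts)
+|     print(backend, X.shape)
+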
Dask and Spark are commented out by default, as these need command-line configuration. +All three distributed backends can be configured to run across a distributed cluster, as done in the +commented-out code. -Parallel prediction is also demonstrated in wordbag_regressor_spark.py. By calling the class with predict_parallel(), it will parallelize prediction either locally or on Spark, depending on whether a SparkContext has been set for the class. Contributors ============ Antti Puurula + Anders Topper + Cheng-Tsung Liu diff --git a/scripts/backends_benchmark.py b/scripts/backends_benchmark.py new file mode 100644 index 0000000..46de917 --- /dev/null +++ b/scripts/backends_benchmark.py @@ -0,0 +1,116 @@ +import re +from contextlib import closing, contextmanager +import time +from wordbatch.pipelines import WordBatch, Apply, ApplyBatch +from wordbatch.extractors import WordHash, WordBag +from wordbatch.transformers import Tokenizer, Dictionary +from wordbatch.batcher import Batcher +import os +import json +from sklearn.feature_extraction.text import HashingVectorizer +import warnings +import pandas as pd + +tripadvisor_dir= "../data/tripadvisor/json" + +#Configure below to allow Dask / Spark +# scheduler_ip= "169.254.93.14" +# from dask.distributed import Client +# #dask-scheduler --host 169.254.93.14 +# #dask-worker 169.254.93.14:8786 --nprocs 16 +# dask_client = Client(scheduler_ip+":8786") +# +# from pyspark import SparkContext, SparkConf +# # conf= SparkConf().setAll([('spark.executor.memory', '4g'), ('spark.driver.memory', '30g'), +# # ('spark.driver.maxResultSize', '10g')]) +# import os +# os.environ['PYSPARK_PYTHON'] = '/home/USERNAME/anaconda3/envs/ENV_NAME/bin/python' +# conf= SparkConf().setAll([('spark.executor.memory', '4g'), ('spark.driver.memory', '30g'), +# ('spark.driver.maxResultSize', '10g')]).setMaster("spark://169.254.93.14:7077") +# spark_context = SparkContext(conf=conf) + +import ray +#ray start --head --node-ip-address 169.254.93.14 +#ray.init(redis_address=scheduler_ip+":57113") #Change port accordingly +ray.init() + +@contextmanager +def timer(name): + t0 = time.time() + yield + print(name + " done in " + str(time.time() - t0) + "s") + +if 1==1: + texts= [] + for jsonfile in os.listdir(tripadvisor_dir): + with open(tripadvisor_dir + "/" + jsonfile, 'r') as inputfile: + for line in inputfile: + try: + line = json.loads(line.strip()) + except: + continue + for review in line["Reviews"]: + texts.append(review["Content"]) +# pd.to_pickle(texts, "tripadvisor_data.pkl") +# else: +# texts= pd.read_pickle("tripadvisor_data.pkl") + +non_alphanums = re.compile('[\W+]') +nums_re= re.compile("\W*[0-9]+\W*") +triples_re= re.compile(r"(\w)\1{2,}") +trash_re= [re.compile("<[^>]*>"), re.compile("[^a-z0-9' -]+"), re.compile(" [.0-9'-]+ "), re.compile("[-']{2,}"), + re.compile(" '"),re.compile(" +")] +from nltk.stem.porter import PorterStemmer +stemmer= PorterStemmer() + +def normalize_text(text): + text= text.lower() + text= nums_re.sub(" NUM ", text) + text= " ".join([word for word in non_alphanums.sub(" ",text).strip().split() if len(word)>1]) + return text + +print(len(texts)) +backends= [ + ['serial', ""], + ['multiprocessing', ""], + ['loky', ""], + #['dask', dask_client], #Uncomment once configured + #['spark', spark_context], #Uncomment once configured + ['ray', ray] +] + +tasks= [ + "ApplyBatch", + "WordBag", +] + +data_sizes= [40000, 80000, 160000, 320000, 640000, 1280000] + +for task in tasks: + for data_size in data_sizes: + texts_chunk = texts[:data_size] + 
print("Task:", task, "Data size:", data_size) + for backend in backends: + batcher = Batcher(procs=16, minibatch_size=5000, backend=backend[0], backend_handle=backend[1]) + #try: + with timer("Completed: ["+task+","+str(len(texts_chunk))+","+backend[0]+"]"), warnings.catch_warnings(): + warnings.simplefilter("ignore") + if task=="ApplyBatch": + hv = HashingVectorizer(decode_error='ignore', n_features=2 ** 25, preprocessor=normalize_text, + ngram_range=(1, 2), norm='l2') + t= ApplyBatch(hv.transform, batcher=batcher).transform(texts_chunk) + print(t.shape, t.data[:5]) + + if task=="WordBag": + wb = WordBatch(normalize_text=normalize_text, + dictionary=Dictionary(min_df=10, max_words=1000000, verbose=0), + tokenizer= Tokenizer(spellcor_count=2, spellcor_dist=2, raw_min_df= 2, + stemmer= stemmer), + extractor=WordBag(hash_ngrams=0, norm= 'l2', tf= 'binary', idf= 50.0), + batcher= batcher, + verbose= 0) + t = wb.fit_transform(texts_chunk) + print(t.shape, t.data[:5]) + # except: + # print("Failed ["+task+","+str(len(texts_chunk))+","+backend[0]+"]") + print("") \ No newline at end of file diff --git a/scripts/classify_airline_sentiment.py b/scripts/classify_airline_sentiment.py index 3ec7481..0a6ad52 100644 --- a/scripts/classify_airline_sentiment.py +++ b/scripts/classify_airline_sentiment.py @@ -9,92 +9,111 @@ from math import * import time, datetime import multiprocessing - -tripadvisor_dir= "../../../data/tripadvisor/json" +from wordbatch.batcher import Batcher +import wordbatch + +print("Wordbatch version:", wordbatch.__version__) +pd.set_option('display.max_rows', 500) +pd.set_option('display.max_columns', 500) +pd.set_option('display.max_colwidth', 1000) + +backend= "ray" +#backend= "multiprocessing" +minibatch_size= 10000 +if backend == "ray": + import ray + ray.init() + b = Batcher(backend="ray", backend_handle=ray, minibatch_size=minibatch_size) +if backend == "multiprocessing": + b = Batcher(backend="multiprocessing", minibatch_size=minibatch_size) + +tripadvisor_dir= "../data/tripadvisor/json" if __name__ == "__main__": - start_time= time.time() - print(datetime.datetime.now()) - - df= pd.DataFrame.from_csv("../data/Tweets.csv", encoding="utf8") - def sentiment_to_label(sentiment): - if sentiment=="neutral": return 0 - if sentiment=="negative": return -1 - return 1 - - df['airline_sentiment_confidence']= df['airline_sentiment_confidence'].astype('str') - df['sentiment']= (df['airline_sentiment']).apply(lambda x: sentiment_to_label(x)) - df= df[['text','sentiment']] - - re_attags= re.compile(" @[^ ]* ") - re_spaces= re.compile("\w+]") - df['text']= df['text'].apply(lambda x: re_spaces.sub(" ",re_attags.sub(" ", " "+x+" "))[1:-1]) - df= df.drop_duplicates(subset=['text']) - df.index= df['id']= range(df.shape[0]) - - non_alphanums=re.compile('[^A-Za-z]+') - def normalize_text(text): return non_alphanums.sub(' ', text).lower().strip() - df['text_normalized']= df['text'].map(lambda x: normalize_text(x)) - df['textblob_score']= df['text_normalized'].map(lambda x: textblob.TextBlob(x).polarity) - - import wordbag_regressor - print("Train wordbag regressor") - wb_regressor= wordbag_regressor.WordbagRegressor("../models/wordbag_model.pkl.gz", tripadvisor_dir) - #wb_regressor= wordbag_regressor.WordbagRegressor("../models/wordbag_model.pkl.gz") - df['wordbag_score']= wb_regressor.predict(df['text'].values) - print(("%s minutes ---" % round(((time.time() - start_time) / 60), 2))) - - import wordhash_regressor - print("Train wordhash regressor") - wh_regressor= 
wordhash_regressor.WordhashRegressor("../models/wordhash_model.pkl.gz", tripadvisor_dir) - #wh_regressor= wordhash_regressor.WordhashRegressor("../models/wordhash_model.pkl.gz") - df['wordhash_score']= wh_regressor.predict(df['text'].values) - print(("%s minutes ---" % round(((time.time() - start_time) / 60), 2))) - - import wordseq_regressor - print("Train wordseq regressor") - ws_regressor = wordseq_regressor.WordseqRegressor("../models/wordseq_model.pkl.gz", tripadvisor_dir) - #ws_regressor = wordseq_regressor.WordseqRegressor("../models/wordseq_model.pkl.gz") - df['wordseq_score']= ws_regressor.predict_batch(df['text'].values) - print(("%s minutes ---" % round(((time.time() - start_time) / 60), 2))) - - import wordvec_regressor - print("Train wordvec regressor") - wv_regressor= wordvec_regressor.WordvecRegressor("../models/wordvec_model.pkl.gz", tripadvisor_dir) - #wv_regressor= wordvec_regressor.WordvecRegressor("../models/wordvec_model.pkl.gz") - df['wordvec_score'] = wv_regressor.predict(df['text'].values) - print(("%s minutes ---" % round(((time.time() - start_time) / 60), 2))) - - df['tweet_len']= df['text'].map(lambda x: log(1+len(x))) - df['tweet_wordcount']= df['text'].map(lambda x: log(1+len(x.split()))) - - print(df) - full_preds= np.zeros(df.shape[0]) - columns_pick= ['tweet_len', 'tweet_wordcount', 'wordbag_score', 'wordhash_score', 'wordseq_score', 'wordvec_score', 'textblob_score'] #Mean Squared Error: 0.28730889581 - #columns_pick= ['tweet_len', 'tweet_wordcount', 'wordhash_score', 'wordseq_score', 'wordvec_score', 'textblob_score'] # - #columns_pick= ['tweet_len', 'tweet_wordcount', 'wordbag_score', 'wordseq_score', 'wordvec_score', 'textblob_score'] # - #columns_pick= ['tweet_len', 'tweet_wordcount', 'wordbag_score', 'wordhash_score', 'wordvec_score', 'textblob_score'] # - #columns_pick= ['tweet_len', 'tweet_wordcount', 'wordbag_score', 'wordhash_score', 'wordseq_score', 'textblob_score'] # - - kf= KFold(n_splits=10, shuffle=True, random_state=0) - for train_index, dev_index in kf.split(range(df.shape[0])): - df_train= df.ix[train_index] - df_dev= df.ix[dev_index] - clf= RandomForestRegressor(n_estimators=200, criterion='mse', max_depth=None, min_samples_split=5, - min_samples_leaf=2, min_weight_fraction_leaf=0.0, max_features='auto', - max_leaf_nodes=None, bootstrap=True, oob_score=False, - n_jobs=multiprocessing.cpu_count(), random_state=0, - verbose=0, warm_start=False) - - clf.fit(df_train[columns_pick], df_train['sentiment']) - preds= clf.predict(df_dev[columns_pick]) - for x in range(len(preds)): full_preds[df_dev['id'].iloc[x]]= preds[x] - - df['preds']= full_preds - df['preds']= sp.clip(full_preds, -1.0, 1.0) - - print(datetime.datetime.now()) - print(("%s minutes ---" % round(((time.time() - start_time)/60),2))) - - c_mse= sklearn.metrics.mean_squared_error(df['sentiment'], df['preds'], sample_weight=None, - multioutput='uniform_average') - print("Mean Squared Error:", c_mse) + start_time= time.time() + print(datetime.datetime.now()) + + #df= pd.DataFrame.from_csv("../data/Tweets.csv", encoding="utf8") + df = pd.read_csv("../data/Tweets.csv", encoding="utf8") + def sentiment_to_label(sentiment): + if sentiment=="neutral": return 0 + if sentiment=="negative": return -1 + return 1 + d_sentiment_to_label= {"neutral":0, "negative":-1, "positive":1} + df['airline_sentiment_confidence']= df['airline_sentiment_confidence'].astype('str') + df['sentiment'] = (df['airline_sentiment']).map(d_sentiment_to_label) + #df['sentiment']= (df['airline_sentiment']).apply(lambda 
x: sentiment_to_label(x)) + df= df[['text','sentiment']] + + re_attags= re.compile(" @[^ ]* ") + re_spaces= re.compile("\w+]") + df['text']= df['text'].apply(lambda x: re_spaces.sub(" ",re_attags.sub(" ", " "+x+" "))[1:-1]) + df= df.drop_duplicates(subset=['text']) + df.index= df['id']= range(df.shape[0]) + + non_alphanums=re.compile('[^A-Za-z]+') + def normalize_text(text): return non_alphanums.sub(' ', text).lower().strip() + df['text_normalized']= df['text'].map(lambda x: normalize_text(x)) + df['textblob_score']= df['text_normalized'].map(lambda x: textblob.TextBlob(x).polarity) + + import wordbag_regressor + print("Train wordbag regressor") + wb_regressor= wordbag_regressor.WordbagRegressor("../models/wordbag_model.pkl.gz", tripadvisor_dir, b) + #wb_regressor= wordbag_regressor.WordbagRegressor("../models/wordbag_model.pkl.gz") + df['wordbag_score']= wb_regressor.predict(df['text'].values) + print(("%s minutes ---" % round(((time.time() - start_time) / 60), 2))) + + import wordhash_regressor + print("Train wordhash regressor") + wh_regressor= wordhash_regressor.WordhashRegressor("../models/wordhash_model.pkl.gz", tripadvisor_dir, b) + #wh_regressor= wordhash_regressor.WordhashRegressor("../models/wordhash_model.pkl.gz") + df['wordhash_score']= wh_regressor.predict(df['text'].values) + print(("%s minutes ---" % round(((time.time() - start_time) / 60), 2))) + + import wordseq_regressor + print("Train wordseq regressor") + ws_regressor = wordseq_regressor.WordseqRegressor("../models/wordseq_model.pkl.gz", tripadvisor_dir, b) + #ws_regressor = wordseq_regressor.WordseqRegressor("../models/wordseq_model.pkl.gz") + df['wordseq_score']= ws_regressor.predict_batch(df['text'].values) + print(("%s minutes ---" % round(((time.time() - start_time) / 60), 2))) + + import wordvec_regressor + print("Train wordvec regressor") + wv_regressor= wordvec_regressor.WordvecRegressor("../models/wordvec_model.pkl.gz", tripadvisor_dir, b) + #wv_regressor= wordvec_regressor.WordvecRegressor("../models/wordvec_model.pkl.gz") + df['wordvec_score'] = wv_regressor.predict(df['text'].values) + print(df['wordvec_score']) + print(("%s minutes ---" % round(((time.time() - start_time) / 60), 2))) + + df['tweet_len']= df['text'].map(lambda x: log(1+len(x))) + df['tweet_wordcount']= df['text'].map(lambda x: log(1+len(x.split()))) + + print(df) + full_preds= np.zeros(df.shape[0]) + columns_pick= ['tweet_len', 'tweet_wordcount', 'wordbag_score', 'wordhash_score', 'wordseq_score', 'wordvec_score', 'textblob_score'] #Mean Squared Error: 0.28730889581 + #columns_pick= ['tweet_len', 'tweet_wordcount', 'wordhash_score', 'wordseq_score', 'wordvec_score', 'textblob_score'] # + #columns_pick= ['tweet_len', 'tweet_wordcount', 'wordbag_score', 'wordseq_score', 'wordvec_score', 'textblob_score'] # + #columns_pick= ['tweet_len', 'tweet_wordcount', 'wordbag_score', 'wordhash_score', 'wordvec_score', 'textblob_score'] # + #columns_pick= ['tweet_len', 'tweet_wordcount', 'wordbag_score', 'wordhash_score', 'wordseq_score', 'textblob_score'] # + + kf= KFold(n_splits=10, shuffle=True, random_state=0) + for train_index, dev_index in kf.split(range(df.shape[0])): + df_train= df.iloc[train_index] + df_dev= df.iloc[dev_index] + clf= RandomForestRegressor(n_estimators=200, criterion='mse', max_depth=None, min_samples_split=5, + min_samples_leaf=2, min_weight_fraction_leaf=0.0, max_features='auto', + max_leaf_nodes=None, bootstrap=True, oob_score=False, + n_jobs=multiprocessing.cpu_count(), random_state=0, + verbose=0, warm_start=False) + + 
clf.fit(df_train[columns_pick], df_train['sentiment']) + preds= clf.predict(df_dev[columns_pick]) + for x in range(len(preds)): full_preds[df_dev['id'].iloc[x]]= preds[x] + + df['preds']= sp.clip(full_preds, -1.0, 1.0) + + print(datetime.datetime.now()) + print(("%s minutes ---" % round(((time.time() - start_time)/60),2))) + + c_mse= sklearn.metrics.mean_squared_error(df['sentiment'], df['preds'], sample_weight=None, + multioutput='uniform_average') + print("Mean Squared Error:", c_mse) diff --git a/scripts/wordbag_regressor.py b/scripts/wordbag_regressor.py index a1e6932..3660824 100644 --- a/scripts/wordbag_regressor.py +++ b/scripts/wordbag_regressor.py @@ -6,85 +6,90 @@ import re import json import gzip -import wordbatch +from wordbatch.pipelines import WordBatch from wordbatch.models import FTRL from wordbatch.extractors import WordBag import threading import sys if sys.version_info.major == 3: - import pickle as pkl + import pickle as pkl else: - import cPickle as pkl + import cPickle as pkl non_alphanums = re.compile('[\W+]') nums_re= re.compile("\W*[0-9]+\W*") triples_re= re.compile(r"(\w)\1{2,}") trash_re= [re.compile("<[^>]*>"), re.compile("[^a-z0-9' -]+"), re.compile(" [.0-9'-]+ "), re.compile("[-']{2,}"), - re.compile(" '"),re.compile(" +")] + re.compile(" '"),re.compile(" +")] from nltk.stem.porter import PorterStemmer stemmer= PorterStemmer() def normalize_text(text): - text= text.lower() - text= nums_re.sub(" NUM ", text) - text= " ".join([word for word in non_alphanums.sub(" ",text).strip().split() if len(word)>1]) - return text + text= text.lower() + text= nums_re.sub(" NUM ", text) + text= " ".join([word for word in non_alphanums.sub(" ",text).strip().split() if len(word)>1]) + return text class WordbagRegressor(object): - def __init__(self, pickle_model="", datadir=None): - self.wb = wordbatch.WordBatch(normalize_text, extractor=(WordBag, {"hash_ngrams":3, - "hash_ngrams_weights":[-1.0, -1.0, 1.0],"hash_size":2**23, "norm":'l2', "tf":'binary', "idf":50.0}) ) - self.clf= FTRL(alpha=1.0, beta=1.0, L1=0.00001, L2=1.0, D=2 ** 23, iters=1, inv_link="identity") - if datadir==None: (self.wb, self.clf)= pkl.load(gzip.open(pickle_model, 'rb')) - else: self.train(datadir, pickle_model) + def __init__(self, pickle_model="", datadir=None, batcher= None): + self.wb = WordBatch(normalize_text=normalize_text, + extractor=WordBag(hash_ngrams=3, hash_ngrams_weights=[-1.0, -1.0, 1.0], + hash_size=2**23, norm='l2', tf='binary', idf=50.0), + batcher= batcher) - def fit_batch(self, texts, labels, rcount): - texts, labels= self.wb.batcher.shuffle_batch(texts, labels, rcount) - print("Transforming", rcount) - texts= self.wb.fit_transform(texts, reset= False) - print("Training", rcount) - self.clf.fit(texts, labels, reset= False) + self.clf= FTRL(alpha=1.0, beta=1.0, L1=0.00001, L2=1.0, D=2 ** 23, iters=1, inv_link="identity") + if datadir==None: (self.wb, self.clf)= pkl.load(gzip.open(pickle_model, 'rb')) + else: self.train(datadir, pickle_model) - def train(self, datadir, pickle_model=""): - texts= [] - labels= [] - training_data= os.listdir(datadir) - rcount= 0 - batchsize= 100000 + def fit_batch(self, texts, labels, rcount): + texts, labels= self.wb.batcher.shuffle_batch(texts, labels, rcount) + print("Transforming", rcount) + texts= self.wb.fit_transform(texts, reset= False) + print("Training", rcount) + self.clf.fit(texts, labels, reset= False) - p = None - for jsonfile in training_data: - with open(datadir + "/" + jsonfile, 'r') as inputfile: - for line in inputfile: - #if rcount > 
1000000: break - try: line = json.loads(line.strip()) - except: continue - for review in line["Reviews"]: - rcount+= 1 - if rcount % 100000 == 0: print(rcount) - if rcount % 7 != 0: continue - if "Overall" not in review["Ratings"]: continue - texts.append(review["Content"]) - labels.append((float(review["Ratings"]["Overall"]) - 3) *0.5) - if len(texts) % batchsize == 0: - if p != None: p.join() - p= threading.Thread(target=self.fit_batch, args=(texts, labels, rcount)) - p.start() - texts= [] - labels= [] - if p != None: p.join() - self.fit_batch(texts, labels, rcount) + def train(self, datadir, pickle_model=""): + texts= [] + labels= [] + training_data= os.listdir(datadir) + rcount= 0 + batchsize= 100000 - self.wb.dictionary_freeze= True + p = None + for jsonfile in training_data: + with open(datadir + "/" + jsonfile, 'r') as inputfile: + for line in inputfile: + #if rcount > 1000000: break + try: line = json.loads(line.strip()) + except: continue + for review in line["Reviews"]: + rcount+= 1 + if rcount % 100000 == 0: print(rcount) + if rcount % 7 != 0: continue + if "Overall" not in review["Ratings"]: continue + texts.append(review["Content"]) + labels.append((float(review["Ratings"]["Overall"]) - 3) * 0.5) + if len(texts) % batchsize == 0: + if p != None: p.join() + p= threading.Thread(target=self.fit_batch, args=(texts, labels, rcount)) + p.start() + texts= [] + labels= [] + if p != None: p.join() + self.fit_batch(texts, labels, rcount) - if pickle_model!="": - with gzip.open(pickle_model, 'wb') as model_file: - pkl.dump((self.wb, self.clf), model_file, protocol=2) + self.wb.dictionary_freeze= True - def predict(self, texts): - counts= self.wb.transform(texts) - return self.clf.predict(counts) + if pickle_model!="": + with gzip.open(pickle_model, 'wb') as model_file: + backend= self.wb.batcher.backend + backend_handle= self.wb.batcher.backend_handle + self.wb.batcher.backend= "serial" + self.wb.batcher.backend_handle = None + pkl.dump((self.wb, self.clf), model_file, protocol=2) + self.wb.batcher.backend = backend + self.wb.batcher.backend_handle = backend_handle - def predict_parallel(self, texts): - counts= self.wb.transform(texts) - return self.wb.predict_parallel(counts, self.clf) \ No newline at end of file + def predict(self, texts): + counts= self.wb.transform(texts) + return self.clf.predict(counts) \ No newline at end of file diff --git a/scripts/wordbag_regressor_spark.py b/scripts/wordbag_regressor_spark.py index 9d87ff6..3913488 100644 --- a/scripts/wordbag_regressor_spark.py +++ b/scripts/wordbag_regressor_spark.py @@ -10,96 +10,128 @@ from wordbatch.models import FTRL from wordbatch.extractors import WordBag import threading +import pandas as pd import sys if sys.version_info.major == 3: - import pickle as pkl + import pickle as pkl else: - import cPickle as pkl + import cPickle as pkl non_alphanums = re.compile('[\W]') nums_re= re.compile("\W*[0-9]+\W*") triples_re= re.compile(r"(\w)\1{2,}") trash_re= [re.compile("<[^>]*>"), re.compile("[^a-z0-9' -]+"), re.compile(" [.0-9'-]+ "), re.compile("[-']{2,}"), - re.compile(" '"),re.compile(" +")] + re.compile(" '"),re.compile(" +")] from nltk.stem.porter import PorterStemmer stemmer= PorterStemmer() def normalize_text(text): - text= text.lower() - text= nums_re.sub(" NUM ", text) - text= " ".join([word for word in non_alphanums.sub(" ",text).split() if len(word)>1]) - return text + text= text.lower() + text= nums_re.sub(" NUM ", text) + text= " ".join([word for word in non_alphanums.sub(" ",text).split() if len(word)>1]) + 
return text class WordbagRegressor(object): - def __init__(self, pickle_model="", datadir=None): - from pyspark import SparkContext - self.sc= SparkContext() - self.wordbatch = wordbatch.WordBatch(normalize_text, extractor=(WordBag, {"hash_ngrams":3, - "hash_ngrams_weights":[-1.0, -1.0, 1.0],"hash_size":2**23, "norm":'l2', "tf":'binary', "idf":50.0})) - self.clf= FTRL(alpha=1.0, beta=1.0, L1=0.00001, L2=1.0, D=2 ** 25, iters=1, inv_link="identity") - self.wordbatch.use_sc= True - if datadir==None: (self.wordbatch, self.clf)= pkl.load(gzip.open(pickle_model, 'rb')) - else: self.train(datadir, pickle_model) + def __init__(self, pickle_model="", datadir=None): + from pyspark import SparkContext + self.sc= SparkContext() + self.wordbatch = wordbatch.WordBatch(normalize_text, backend="spark", backend_handle=self.sc, + extractor=(WordBag, {"hash_ngrams":3, + "hash_ngrams_weights":[-1.0, -1.0, 1.0], + "hash_size":2**23, "norm":'l2', + "tf":'binary', "idf":50.0})) + self.clf= FTRL(alpha=1.0, beta=1.0, L1=0.00001, L2=1.0, D=2 ** 23, iters=1, inv_link="identity") + if datadir==None: (self.wordbatch, self.clf)= pkl.load(gzip.open(pickle_model, 'rb')) + else: self.train(datadir, pickle_model) - def fit_batch(self, texts, labels, rcount): - texts, labels= self.wordbatch.shuffle_batch(texts, labels, rcount) - print("Transforming", rcount) - if self.sc != None: - data_rdd= self.wordbatch.lists2rddbatches([texts, labels], self.sc) - data_rdd= self.wordbatch.transform(data_rdd) - [texts, labels]= self.wordbatch.rddbatches2lists(data_rdd) - else: - texts= self.wordbatch.transform(texts) - print("Training", rcount) - self.clf.fit(texts, labels) + def fit_batch(self, texts, labels, rcount): + print("Transforming", rcount) + # if self.sc != None: + # data_rdd= self.wordbatch.lists2rddbatches([texts, labels], self.sc) + # data_rdd= self.wordbatch.transform(data_rdd) + # [texts, labels]= self.wordbatch.rddbatches2lists(data_rdd) + # else: + # print(texts[:2]) + # print(pd.Series(labels).value_counts()) + texts= self.wordbatch.partial_fit_transform(texts) + print("Training", rcount) + self.clf.partial_fit(texts, labels) - def train(self, datadir, pickle_model=""): - texts= [] - labels= [] - training_data= os.listdir(datadir) - rcount= 0 - batchsize= 20000 + def train(self, datadir, pickle_model=""): + texts= [] + labels= [] + training_data= os.listdir(datadir) + rcount= 0 + batchsize= 20000 - p = None - for jsonfile in training_data: - with open(datadir + "/" + jsonfile, 'r') as inputfile: - for line in inputfile: - #if rcount > 1000000: break - try: line = json.loads(line.strip()) - except: continue - for review in line["Reviews"]: - rcount+= 1 - if rcount % 100000 == 0: print(rcount) - if rcount % 7 != 0: continue - if "Overall" not in review["Ratings"]: continue - texts.append(review["Content"]) - labels.append((float(review["Ratings"]["Overall"]) - 3) *0.5) - if len(texts) % batchsize == 0: - if p != None: p.join() - p= threading.Thread(target=self.fit_batch, args=(texts, labels, rcount)) - p.start() - texts= [] - labels= [] - if p != None: p.join() - self.fit_batch(texts, labels, rcount) + p = None + for jsonfile in training_data: + with open(datadir + "/" + jsonfile, 'r') as inputfile: + for line in inputfile: + #if rcount > 1000000: break + try: line = json.loads(line.strip()) + except: continue + for review in line["Reviews"]: + rcount+= 1 + if rcount % 100000 == 0: print(rcount) + if rcount % 7 != 0: continue + if "Overall" not in review["Ratings"]: continue + texts.append(review["Content"]) + 
labels.append((float(review["Ratings"]["Overall"]) - 3) *0.5) + if len(texts) % batchsize == 0: + if p != None: p.join() + p= threading.Thread(target=self.fit_batch, args=(texts, labels, rcount)) + p.start() + texts= [] + labels= [] + if p != None: p.join() + self.fit_batch(texts, labels, rcount) - self.wordbatch.dictionary_freeze= True + self.wordbatch.dictionary_freeze= True - if pickle_model!="": - with gzip.open(pickle_model, 'wb') as model_file: - pkl.dump((self.wordbatch, self.clf), model_file, protocol=2) + if pickle_model!="": + with gzip.open(pickle_model, 'wb') as model_file: + pkl.dump((self.wordbatch, self.clf), model_file, protocol=2) - def predict(self, texts): - if self.sc != None: - data_rdd= self.wordbatch.lists2rddbatches([texts, []], self.sc) - data_rdd= self.wordbatch.transform(data_rdd) - [counts, labels]= self.wordbatch.rddbatches2lists(data_rdd) - else: counts= self.wordbatch.transform(texts) - return self.clf.predict(counts) + def predict(self, texts): + # if self.sc != None: + # data_rdd= self.wordbatch.lists2rddbatches([texts, []], self.sc) + # data_rdd= self.wordbatch.transform(data_rdd) + # [counts, labels]= self.wordbatch.rddbatches2lists(data_rdd) + # else: + counts= self.wordbatch.transform(texts) + return self.clf.predict(counts) - def predict_parallel(self, texts): - if self.sc != None: - data_rdd= self.wordbatch.lists2rddbatches([texts, []], self.sc) - counts_rdd= self.wordbatch.transform(data_rdd) - return self.wordbatch.rddbatches2lists(self.wordbatch.predict_parallel(counts_rdd, self.clf))[0] - counts= self.wordbatch.transform(texts) - return self.wordbatch.predict_parallel(counts, self.clf) \ No newline at end of file + def predict_parallel(self, texts): + # if self.sc != None: + # data_rdd= self.wordbatch.lists2rddbatches([texts, []] , self.sc) + # counts_rdd= self.wordbatch.transform(data_rdd) + # return self.wordbatch.rddbatches2lists(self.wordbatch.predict_parallel(counts_rdd, self.clf))[0] + counts= self.wordbatch.transform(texts) + return self.wordbatch.predict_parallel(counts, self.clf) + +if __name__ == "__main__": + df= pd.DataFrame.from_csv("../data/Tweets.csv", encoding="utf8") + def sentiment_to_label(sentiment): + if sentiment=="neutral": return 0 + if sentiment=="negative": return -1 + return 1 + + df['airline_sentiment_confidence']= df['airline_sentiment_confidence'].astype('str') + df['sentiment']= (df['airline_sentiment']).apply(lambda x: sentiment_to_label(x)) + df= df[['text','sentiment']] + + re_attags= re.compile(" @[^ ]* ") + re_spaces= re.compile("\w+]") + df['text']= df['text'].apply(lambda x: re_spaces.sub(" ",re_attags.sub(" ", " "+x+" "))[1:-1]) + df= df.drop_duplicates(subset=['text']) + df.index= df['id']= range(df.shape[0]) + + non_alphanums=re.compile('[^A-Za-z]+') + def normalize_text(text): return non_alphanums.sub(' ', text).lower().strip() + df['text_normalized']= df['text'].map(lambda x: normalize_text(x)) + + import wordbag_regressor + print("Train wordbag regressor") + wb_regressor= WordbagRegressor("", "../../../data/tripadvisor/json") + df['wordbag_score']= wb_regressor.predict(df['text'].values) + print(df['wordbag_score'].value_counts()) \ No newline at end of file diff --git a/scripts/wordhash_regressor.py b/scripts/wordhash_regressor.py index 6deec97..58b7b49 100644 --- a/scripts/wordhash_regressor.py +++ b/scripts/wordhash_regressor.py @@ -7,9 +7,10 @@ import json import gzip import scipy.sparse as ssp -import wordbatch +from wordbatch.pipelines import WordBatch from wordbatch.extractors import WordHash from 
wordbatch.models import FM_FTRL +from wordbatch.transformers import Tokenizer import threading import multiprocessing import sys @@ -36,10 +37,9 @@ def __init__(self): self.texts= None class WordhashRegressor(object): - def __init__(self, pickle_model="", datadir=None): - self.wb= wordbatch.WordBatch(normalize_text, stemmer= stemmer, - extractor=(WordHash, {"decode_error":'ignore', "n_features":2 ** 25, - "non_negative":False, "ngram_range":(1,2), "norm":'l2'})) + def __init__(self, pickle_model="", datadir=None, batcher= None): + self.wb= WordBatch(normalize_text, tokenizer= Tokenizer(stemmer=stemmer), extractor=WordHash( + decode_error='ignore', n_features=2 ** 25, ngram_range=(1,2), norm='l2'), batcher= batcher) self.clf = FM_FTRL(D=2 ** 25, D_fm= 4, iters=1, inv_link="identity", threads= multiprocessing.cpu_count()//2) if datadir==None: (self.wb, self.clf)= pkl.load(gzip.open(pickle_model, 'rb')) else: self.train(datadir, pickle_model) @@ -90,7 +90,13 @@ def train(self, datadir, pickle_model=""): self.clf.fit(texts, labels) if pickle_model != "": with gzip.open(pickle_model, 'wb') as model_file: + backend= self.wb.batcher.backend + backend_handle= self.wb.batcher.backend_handle + self.wb.batcher.backend= "serial" + self.wb.batcher.backend_handle = None pkl.dump((self.wb, self.clf), model_file, protocol=2) + self.wb.batcher.backend = backend + self.wb.batcher.backend_handle = backend_handle def predict(self, texts): counts= self.wb.transform(texts) diff --git a/scripts/wordseq_regressor.py b/scripts/wordseq_regressor.py index 3b721e5..16a902a 100644 --- a/scripts/wordseq_regressor.py +++ b/scripts/wordseq_regressor.py @@ -8,10 +8,12 @@ import os import json import scipy as sp +import numpy as np from keras.layers import * from keras.models import Sequential -import wordbatch +from wordbatch.pipelines import WordBatch from wordbatch.extractors import WordSeq +from wordbatch.transformers import Dictionary import random import threading from keras.models import load_model @@ -34,7 +36,7 @@ def __init__(self): self.labels= None class WordseqRegressor(): - def __init__(self, pickle_model="", datadir=None): + def __init__(self, pickle_model="", datadir=None, batcher= None): seed = 10002 session_conf = tf.ConfigProto(intra_op_parallelism_threads=multiprocessing.cpu_count()//2, inter_op_parallelism_threads=1) @@ -46,11 +48,10 @@ def __init__(self, pickle_model="", datadir=None): self.maxlen = 200 self.max_words = 20000 - self.wb= wordbatch.WordBatch(normalize_text, max_words=self.max_words, - extractor=(WordSeq, {"seq_maxlen": self.maxlen})) + self.wb= WordBatch(normalize_text, dictionary=Dictionary(max_words=self.max_words), + extractor=WordSeq(seq_maxlen=self.maxlen), batcher=batcher) self.model = Sequential() self.model.add(Embedding(self.max_words+2, 20, input_length=self.maxlen)) - self.model.add(Conv1D(activation="relu", padding="same", strides=1, filters=10, kernel_size=3)) self.model.add(Dropout(0.5)) self.model.add(BatchNormalization()) @@ -108,7 +109,13 @@ def train(self, datadir, pickle_model=""): self.model.fit(train[0], train[1], batch_size=2048, epochs=2, validation_data=(test[0], test[1])) if pickle_model != "": self.model.save(pickle_model) + backend = self.wb.batcher.backend + backend_handle = self.wb.batcher.backend_handle + self.wb.batcher.backend = "serial" + self.wb.batcher.backend_handle = None with gzip.open(pickle_model + ".wb", 'wb') as model_file: pkl.dump(self.wb, model_file, protocol=2) + self.wb.batcher.backend = backend + self.wb.batcher.backend_handle = 
backend_handle def predict_batch(self, texts): results= [x[0] for x in self.model.predict(np.array(self.wb.transform(texts)))] diff --git a/scripts/wordvec_regressor.py b/scripts/wordvec_regressor.py index bd75e9f..ca37269 100644 --- a/scripts/wordvec_regressor.py +++ b/scripts/wordvec_regressor.py @@ -6,85 +6,100 @@ import re import json import gzip -import wordbatch +from wordbatch.pipelines import WordBatch from wordbatch.models import FTRL from wordbatch.extractors import WordVec, Hstack from wordbatch.data_utils import shuffle import threading import sys if sys.version_info.major == 3: - import pickle as pkl + import pickle as pkl else: - import cPickle as pkl + import cPickle as pkl non_alphanums = re.compile('[\W+]') nums_re= re.compile("\W*[0-9]+\W*") trash_re= [re.compile("<[^>]*>"), re.compile("[^a-z0-9' -]+"), re.compile(" [.0-9'-]+ "), re.compile("[-']{2,}"), - re.compile(" '"),re.compile(" +")] + re.compile(" '"),re.compile(" +")] def normalize_text(text): - text= text.lower() - text= nums_re.sub(" NUM ", text) - text= " ".join([word for word in non_alphanums.sub(" ",text).strip().split() if len(word)>1]) - return text + text= text.lower() + text= nums_re.sub(" NUM ", text) + text= " ".join([word for word in non_alphanums.sub(" ",text).strip().split() if len(word)>1]) + return text class WordvecRegressor(object): - def __init__(self, pickle_model="", datadir=None): - self.wb= wordbatch.WordBatch(normalize_text, - extractor=(Hstack, - [(WordVec, {"wordvec_file": "../../../data/word2vec/glove.twitter.27B.100d.txt.gz", - "normalize_text": normalize_text, "encoding": "utf8"}), - (WordVec, {"wordvec_file": "../../../data/word2vec/glove.6B.50d.txt.gz", - "normalize_text": normalize_text, "encoding": "utf8"})])) + def __init__(self, pickle_model="", datadir=None, batcher=None): + self.wb= WordBatch(normalize_text, extractor=Hstack([ + WordVec(wordvec_file="../../../data/word2vec/glove.twitter.27B.100d.txt.gz", normalize_text=normalize_text, + encoding="utf8"), + WordVec(wordvec_file="../../../data/word2vec/glove.6B.50d.txt.gz", normalize_text=normalize_text, + encoding="utf8")])) + # from wordbatch.pipelines import FeatureUnion + # from wordbatch.transformers import Dictionary, TextNormalizer + # from sklearn.pipeline import Pipeline + # tn= TextNormalizer(normalize_text=normalize_text) + # dct= Dictionary() + # vec1= WordVec(wordvec_file="../../../data/word2vec/glove.twitter.27B.100d.txt.gz", + # normalize_text=normalize_text, encoding="utf8", dictionary= dct) + # vec2= WordVec(wordvec_file="../../../data/word2vec/glove.6B.50d.txt.gz", + # normalize_text=normalize_text, encoding="utf8", dictionary= dct) + # self.wb = Pipeline(steps= [("tn", tn), ("dct", dct), ("vecs", FeatureUnion([("vec1", vec1), ("vec2", vec2)]))]) + self.batcher= batcher - self.wb.dictionary_freeze= True + self.clf= FTRL(alpha=1.0, beta=1.0, L1=0.00001, L2=1.0, D=100+50, iters=1, inv_link= "identity") - self.clf= FTRL(alpha=1.0, beta=1.0, L1=0.00001, L2=1.0, D=100+50, iters=1, inv_link= "identity") + if datadir==None: (self.wb, self.clf)= pkl.load(gzip.open(pickle_model, 'rb')) + else: self.train(datadir, pickle_model) - if datadir==None: (self.wb, self.clf)= pkl.load(gzip.open(pickle_model, 'rb')) - else: self.train(datadir, pickle_model) + def fit_batch(self, texts, labels, rcount): + texts, labels = shuffle(texts, labels, seed=rcount) + print("Transforming", rcount) + #texts= self.wb.fit_transform(texts, tn__batcher=self.batcher, dct__reset= False, dct__batcher= self.batcher) + texts = 
self.wb.fit_transform(texts) + print("Training", rcount) + self.clf.fit(texts, labels, reset= False) - def fit_batch(self, texts, labels, rcount): - texts, labels = shuffle(texts, labels, seed=rcount) - print("Transforming", rcount) - texts= self.wb.fit_transform(texts, reset= False) - print("Training", rcount) - self.clf.fit(texts, labels, reset= False) + def train(self, datadir, pickle_model=""): + texts= [] + labels= [] + training_data= os.listdir(datadir) + rcount= 0 + batchsize= 80000 - def train(self, datadir, pickle_model=""): - texts= [] - labels= [] - training_data= os.listdir(datadir) - rcount= 0 - batchsize= 80000 + p= None + for jsonfile in training_data: + with open(datadir + "/" + jsonfile, 'r') as inputfile: + for line in inputfile: + #if rcount > 1000000: break + try: line= json.loads(line.strip()) + except: continue + for review in line["Reviews"]: + rcount+= 1 + if rcount % 100000 == 0: print(rcount) + if rcount % 6 != 0: continue + if "Overall" not in review["Ratings"]: continue + texts.append(review["Content"]) + labels.append((float(review["Ratings"]["Overall"]) - 3) *0.5) + if len(texts) % batchsize == 0: + if p != None: p.join() + p= threading.Thread(target=self.fit_batch, args=(texts, labels, rcount)) + p.start() + texts= [] + labels= [] + if p != None: p.join() + self.fit_batch(texts, labels, rcount) - p= None - for jsonfile in training_data: - with open(datadir + "/" + jsonfile, 'r') as inputfile: - for line in inputfile: - #if rcount > 1000000: break - try: line= json.loads(line.strip()) - except: continue - for review in line["Reviews"]: - rcount+= 1 - if rcount % 100000 == 0: print(rcount) - if rcount % 6 != 0: continue - if "Overall" not in review["Ratings"]: continue - texts.append(review["Content"]) - labels.append((float(review["Ratings"]["Overall"]) - 3) *0.5) - if len(texts) % batchsize == 0: - if p != None: p.join() - p= threading.Thread(target=self.fit_batch, args=(texts, labels, rcount)) - p.start() - texts= [] - labels= [] - if p != None: p.join() - self.fit_batch(texts, labels, rcount) + # if pickle_model!="": + # with gzip.open(pickle_model, 'wb') as model_file: + # backend = self.wb.batcher.backend + # backend_handle = self.wb.batcher.backend_handle + # self.wb.batcher.backend = "serial" + # self.wb.batcher.backend_handle = None + # pkl.dump((self.wb, self.clf), model_file, protocol=2) + # self.wb.batcher.backend = backend + # self.wb.batcher.backend_handle = backend_handle - if pickle_model!="": - with gzip.open(pickle_model, 'wb') as model_file: - pkl.dump((self.wb, self.clf), model_file, protocol=2) - - def predict(self, texts): - vecs= self.wb.transform(texts) - return self.clf.predict(vecs) \ No newline at end of file + def predict(self, texts): + vecs= self.wb.transform(texts) + return self.clf.predict(vecs) \ No newline at end of file diff --git a/setup.py b/setup.py index aa9406a..9998558 100644 --- a/setup.py +++ b/setup.py @@ -13,16 +13,17 @@ setup( name='Wordbatch', - version='1.3.8', - description='Parallel text feature extraction for machine learning', + version='1.4.0', + description='Python library for distributed AI processing pipelines, using swappable scheduler backends', url='https://github.com/anttttti/Wordbatch', author='Antti Puurula', author_email='antti.puurula@yahoo.com', packages=['wordbatch', - 'wordbatch.extractors', - 'wordbatch.models', - 'wordbatch.transformers' + 'wordbatch.pipelines', + 'wordbatch.extractors', + 'wordbatch.models', + 'wordbatch.transformers' ], license='GNU GPL 2.0', @@ -30,7 +31,6 @@ "Development 
Status :: 4 - Beta", "Intended Audience :: Developers", "License :: OSI Approved :: GNU General Public License v2 (GPLv2)", - "Programming Language :: Python :: 2.7", "Programming Language :: Python :: 3.6", "Programming Language :: Python :: 3.7", "Programming Language :: Cython", @@ -38,8 +38,8 @@ "Topic :: Software Development :: Libraries :: Python Modules", ], install_requires=['cython', 'scikit-learn', 'python-Levenshtein', 'py-lz4framed', 'randomgen', 'numpy', 'scipy', - 'pandas', 'dill'], - extras_require={'dev': ['nltk', 'textblob', 'keras', 'pyspark']}, + 'pandas'], + extras_require={'dev': ['nltk', 'textblob', 'keras', 'pyspark', 'dask', 'distributed', 'ray']}, cmdclass= {'build_ext': build_ext}, diff --git a/wordbatch/__init__.py b/wordbatch/__init__.py index 8a1bb66..1efbcba 100644 --- a/wordbatch/__init__.py +++ b/wordbatch/__init__.py @@ -1,5 +1,4 @@ import os PACKAGE_DIR = os.path.dirname(os.path.abspath(__file__)) -__version__ = '1.3.8' +__version__ = '1.4.0' -from .wordbatch import * diff --git a/wordbatch/batcher.py b/wordbatch/batcher.py index 08b65a2..764508e 100644 --- a/wordbatch/batcher.py +++ b/wordbatch/batcher.py @@ -3,84 +3,78 @@ from __future__ import division from __future__ import absolute_import from __future__ import print_function -import dill import multiprocessing -import types from contextlib import closing import scipy.sparse as ssp import random import pandas as pd -import sys from math import ceil + class Batcher(object): """Scheduler to handle parallel jobs on minibatches Parameters ---------- procs: int - Number of process(es)/thread(s) for executing task in parallel + Number of process(es)/thread(s) for executing task in parallel. Used for multiprocessing, threading and Loky minibatch_size: int - Expected size of each minibatch to individually perform task on. + Expected size of each minibatch + + backend: {'serial', 'multiprocessing', 'threading', 'loky', 'spark', 'dask', 'ray'} + Backend for computing the tasks + + - 'serial' sequential execution without a backend scheduler - timeout: int | float - Maximum time(seconds) to wait until the jobs finish or until a timeout occurs. + - 'multiprocessing' Python standard multiprocessing library - spark_context: boolean - Use SparkContext as backend. If set, input data is assumed to be RDD in format generated by lists2rddbatches(). + - 'threading' Python standard threading library - method: {'serial', 'multiprocessing', 'threading'} - Backend method to apply task to the minibatches. Spark is special case, set by "spark_context" + - 'loky' Loky fork of multiprocessing library - - 'serial' will apply task sequentially to minibatches without additional parallelism. + - 'spark' PySpark local or distributed execution - - 'multiprocessing' will apply task to minibatches in parallel using multiple processes. + - 'dask' Dask Distributed local or distributed execution - - 'threading' will apply task to minibatches in parallel using multiple threads. + - 'ray' Ray local or distributed execution + + backend_handle: object + Backend handle for sending tasks verbose: int Verbosity level. Setting verbose > 0 will display additional information depending on the specific level set. 
""" - def __init__(self, procs= 0, minibatch_size= 20000, timeout= 600, spark_context= None, method= "multiprocessing", - verbose= 1): + def __init__(self, procs= 0, minibatch_size= 20000, backend_handle= None, backend= "multiprocessing", verbose= 0): if procs==0: procs= multiprocessing.cpu_count() self.procs= procs self.verbose= verbose self.minibatch_size= minibatch_size - self.timeout= timeout - self.spark_context= spark_context - self.method= method + self.backend_handle= backend_handle + self.backend= backend - def lists2rddbatches(self, lists, minibatch_size=0): + def list2indexedrdd(self, lst, minibatch_size=0): if minibatch_size==0: minibatch_size= self.minibatch_size - start= 0; len_data= len(lists[0]); batch_count= 0 + start= 0; len_data= len(lst); batch_count= 0 batches= [] while start < len_data: - batches.append([batch_count]+[column[start:start + minibatch_size] for column in lists]) + batches.append([batch_count]+[lst[start:start + minibatch_size]]) start+= minibatch_size batch_count+= 1 - return self.spark_context.parallelize(batches) + return self.backend_handle.parallelize(batches) - def rddbatches2lists(self, rddbatches, sort= True): - batches= rddbatches.collect() + def indexedrdd2list(self, indexedrdd, sort= True): + batches= indexedrdd.collect() if sort: batches= sorted(batches) - texts= [] - labels= [] - for batch in batches: - texts.append(batch[1]) - labels.append(batch[2]) - texts= self.merge_batches(texts) - labels= self.merge_batches(labels) - return texts, labels + return [batch[1] for batch in batches] - def split_batches(self, data, minibatch_size= None): + def split_batches(self, data, minibatch_size= None, backend= None): """Split data into minibatches with a specified size Parameters ---------- data: iterable and indexable - List-like data to be split into batches. Includes spark_contextipy matrices and Pandas DataFrames. + List-like data to be split into batches. Includes backend_handleipy matrices and Pandas DataFrames. minibatch_size: int Expected sizes of minibatches split from the data. @@ -90,16 +84,28 @@ def split_batches(self, data, minibatch_size= None): data_split: list List of minibatches, each entry is a list-like object representing the data subset in a batch. 
""" - if minibatch_size==None: minibatch_size= self.minibatch_size - if isinstance(data, list) or isinstance(data, tuple): len_data= len(data) + if minibatch_size is None: minibatch_size= self.minibatch_size + if backend is None: backend= self.backend + if isinstance(data, list) or isinstance(data, tuple) or isinstance(data, dict): len_data= len(data) else: len_data= data.shape[0] + if backend=="spark": return self.list2indexedrdd(data, minibatch_size) if isinstance(data,pd.DataFrame): - data_split = [data.iloc[x * minibatch_size:(x + 1) * minibatch_size] for x in + data= [data.iloc[x * minibatch_size:(x + 1) * minibatch_size] for x in range(int(ceil(len_data / minibatch_size)))] + elif isinstance(data, dict): + data = [dict(list(data.items())[x * minibatch_size:min(len_data, (x + 1) * minibatch_size)]) + for x in range(int(ceil(len_data / minibatch_size)))] else: - data_split= [data[x* minibatch_size:min(len_data, (x+1)*minibatch_size)] - for x in range(int(ceil(len_data/minibatch_size)))] - return data_split + data= [data[x* minibatch_size:min(len_data, (x+1)*minibatch_size)] + for x in range(int(ceil(len_data/minibatch_size)))] + ###if backend=="dask": return self.backend_handle.scatter(data) + return data + + def collect_batches(self, data, backend= None, sort= True): + if backend is None: backend= self.backend + if backend == "spark": data= self.indexedrdd2list(data, sort) + if backend == "dask": data = self.backend_handle.gather(data) + return data def merge_batches(self, data): """Merge a list of data minibatches into one single instance representing the data @@ -118,8 +124,8 @@ def merge_batches(self, data): if isinstance(data[0], pd.DataFrame) or isinstance(data[0], pd.Series): return pd.concat(data) return [item for sublist in data for item in sublist] - def parallelize_batches(self, task, data, args, method=None, timeout=-1, input_split=False, - merge_output= True, minibatch_size= None, procs=None): + def process_batches(self, task, data, args, backend=None, backend_handle=None, input_split=False, + merge_output= True, minibatch_size= None, procs=None, verbose= None): """ Parameters @@ -133,31 +139,40 @@ def parallelize_batches(self, task, data, args, method=None, timeout=-1, input_s args: list Arguments to pass to the specified function following the mini-batch - method: {'serial', 'multiprocessing', 'threading'} - Method to apply task to the batches - - - 'serial' will apply task sequentially to the batches without additional parallelism. - - - 'multiprocessing' will apply task to batches in parallel using multiple processes. - - - 'threading' will apply task to batches in parallel using multiple threads. - - timeout: int | float - Maximum time(seconds) to wait until the jobs finish or until a timeout occurs. - input_split: boolean, default False If True, input data is already mapped into minibatches, otherwise data will be split on call. merge_output: boolean, default True If True, results from minibatches will be reduced into one single instance before return. + procs: int + Number of process(es)/thread(s) for executing task in parallel. Used for multiprocessing, threading and Loky + minibatch_size: int - Expected size of each mini-batch to individually perform task on. The actual sizes will be - the same as the specified value except the last mini-batch, whose size might be exactly the same - as this value or smaller. + Expected size of each minibatch - procs: int - Number of process(es)/thread(s) to use to execute task in parallel. 
+ backend: {'serial', 'multiprocessing', 'threading', 'loky', 'spark', 'dask', 'ray'} + Backend for computing the tasks + + - 'serial' sequential execution without a backend scheduler + + - 'multiprocessing' Python standard multiprocessing library + + - 'threading' Python standard threading library + + - 'loky' Loky fork of multiprocessing library + + - 'spark' PySpark local or distributed execution + + - 'dask' Dask Distributed local or distributed execution + + - 'ray' Ray local or distributed execution + + backend_handle: object + Backend handle for sending tasks + + verbose: int + Verbosity level. Setting verbose > 0 will display additional information depending on the specific level set. Returns ------- @@ -166,65 +181,66 @@ def parallelize_batches(self, task, data, args, method=None, timeout=-1, input_s the dataset, with each entry as a sample. Otherwise this will be a list of list-like objects, with each entry representing the results from a minibatch. """ - if procs==None: procs= self.procs - if method == None: method= self.method - if self.verbose > 1: - print("Parallel task:", task, " Method:", method, " Procs:", self.procs, - " input_split:", input_split, " merge_output:", merge_output) - if method=="spark": - #Assumes data is in column 1 of the RDD, as generated by lists2rddbatches - rdd_col= 1 - if not (input_split): data= self.lists2rddbatches(data) - def apply_func(batch): return batch[:rdd_col]+[task([batch[rdd_col]]+args)]+batch[rdd_col+1:] - results= data.map(apply_func) - if merge_output: results= self.rddbatches2lists(results) - return results - - if self.verbose> 10: + if procs is None: procs= self.procs + if backend is None: backend= self.backend + if backend_handle is None: backend_handle = self.backend + if verbose is None: verbose= self.verbose + if verbose > 1: + print("Parallel task:", task, " backend:", backend, " procs:", self.procs, + " input_split:", input_split, " merge_output:", merge_output) + + if verbose> 10: print("len(data):", len(data), "len(args):", len(args), "[type(x) for x in data]:", [type(x) for x in data], "[type(x) for x in args]:", [type(x) for x in args]) - if timeout==-1: timeout= self.timeout - attempt= 0 if not(input_split): - paral_params= [[data_batch]+ args for data_batch in self.split_batches(data, minibatch_size)] + if backend=="spark": + paral_params= self.split_batches(data, minibatch_size, backend="spark") + else: + paral_params= [[data_batch]+ args for data_batch in self.split_batches(data, minibatch_size)] else: - paral_params= [[data_batch]+ args for data_batch in data] - if self.verbose > 1: print("Start task, len(paral_params)", len(paral_params)) - if method == "serial": + if backend!="spark": paral_params= [[data_batch]+ args for data_batch in data] + else: paral_params= data + if verbose > 1: print("Start task, len(paral_params)", len(paral_params)) + if backend == "serial": results = [task(minibatch) for minibatch in paral_params] else: - while (attempt != -1): - try: - if method=="multiprocessing": - with closing(multiprocessing.Pool(max(1, procs), maxtasksperchild=2)) as pool: - results= pool.map_async(task, paral_params) - if timeout==0: - pool.close() - pool.join() - results= results.get() - else: - results.wait(timeout=timeout) - if results.ready(): results= results.get() - else: raise ValueError('Parallelization timeout') - elif method=="threading": - with closing(multiprocessing.dummy.Pool(max(1,procs))) as pool: - results= pool.map(task, paral_params) - pool.close() - pool.join() - #elif method == 
"parallelpython": - # job_server= pp.Server() - # jobs= [job_server.submit(task, (x,), (), ()) for x in paral_params] - # results= [x() for x in jobs] - except: - print("Parallelization fail. Method:", method, "Task:", task) - attempt+= 1 - if timeout!=0: timeout*= 2 - if attempt>=5: return None - print("Retrying, attempt:", attempt, "timeout limit:", timeout, "seconds") - continue - attempt= -1 - if merge_output: return self.merge_batches(results) + if backend=="multiprocessing": + with closing(multiprocessing.Pool(max(1, procs), maxtasksperchild=2)) as pool: + results = pool.map_async(task, paral_params) + pool.close() + pool.join() + results= results.get() + elif backend=="threading": + with closing(multiprocessing.dummy.Pool(max(1,procs))) as pool: + results= pool.map(task, paral_params) + pool.close() + pool.join() + if backend=="loky": + from loky import get_reusable_executor + pool= get_reusable_executor(max_workers=max(1, procs)) + results= list(pool.map(task, paral_params)) + elif backend == "dask": + ###if not (input_split): data= self.scatter(data) + results = [self.backend_handle.submit(task, params) for params in paral_params] + elif backend == "spark": + def apply_func_to_indexedrdd(batch): + return [batch[0]] + [task([batch[1]] + args)] + results = paral_params.map(apply_func_to_indexedrdd) + elif backend == "ray": + import ray + @ray.remote + def f_ray(f, data): + return f(data) + results = [f_ray.remote(task, params) for params in paral_params] + results = [self.backend_handle.get(x) for x in results] #Slower, but handles edge cases + #results= self.backend_handle.get(results) #Faster, but crashes on edge cases? + #results = self.backend_handle.get([f_ray.remote(task, params) for params in paral_params]) + #ppft currently not supported. 
Supporting arbitrary tasks requires modifications to passed arguments + #elif backend == "ppft": + # jobs = [self.backend_handle.submit(task, (x,), (), ()) for x in paral_params] + # results = [x() for x in jobs] + if merge_output: return self.merge_batches(self.collect_batches(results, backend=backend)) return results def shuffle_batch(self, texts, labels= None, seed= None): diff --git a/wordbatch/extractors/extractors.pyx b/wordbatch/extractors/extractors.pyx index f97053d..b5af241 100644 --- a/wordbatch/extractors/extractors.pyx +++ b/wordbatch/extractors/extractors.pyx @@ -4,7 +4,6 @@ from __future__ import with_statement from __future__ import division from __future__ import absolute_import from __future__ import print_function -import types from sklearn.utils.murmurhash import murmurhash3_32 from sklearn.feature_extraction.text import HashingVectorizer #from nltk.metrics import edit_distance @@ -14,9 +13,7 @@ import numpy as np import gzip import lz4framed import array -import sys from wordbatch.data_utils import indlist2csrmatrix -import wordbatch.transformers.dictionary from cpython cimport array cimport cython from libc.stdlib cimport abs @@ -69,23 +66,21 @@ cdef class TextRow: self.fea_weights[index]= weight class WordBag: - def __init__(self, batcher, dictionary, fea_cfg): - self.batcher= batcher - self.dictionary= dictionary - fea_cfg.setdefault("norm", 'l2') - fea_cfg.setdefault("tf", 'log') - fea_cfg.setdefault("idf", 0.0) - fea_cfg.setdefault("hash_ngrams", 0) - fea_cfg.setdefault("hash_ngrams_weights", None) - fea_cfg.setdefault("hash_size", 10000000) - fea_cfg.setdefault("hash_polys_window", 0) - fea_cfg.setdefault("hash_polys_mindf", 5) - fea_cfg.setdefault("hash_polys_maxdf", 0.5) - fea_cfg.setdefault("hash_polys_weight", 0.1) - fea_cfg.setdefault("seed", 0) - for key, value in fea_cfg.items(): setattr(self, key.lower(), value) - if self.hash_ngrams_weights==None: self.hash_ngrams_weights= [1.0 for x in range(self.hash_ngrams)] - if self.hash_ngrams== 0: self.hash_size= self.dictionary.max_words + def __init__(self, *args, **kwargs): + self.dictionary= kwargs.get('dictionary', None) + kwargs.setdefault("norm", 'l2') + kwargs.setdefault("tf", 'log') + kwargs.setdefault("idf", 0.0) + kwargs.setdefault("hash_ngrams", 0) + kwargs.setdefault("hash_ngrams_weights", None) + kwargs.setdefault("hash_size", 10000000) + kwargs.setdefault("hash_polys_window", 0) + kwargs.setdefault("hash_polys_mindf", 5) + kwargs.setdefault("hash_polys_maxdf", 0.5) + kwargs.setdefault("hash_polys_weight", 0.1) + kwargs.setdefault("seed", 0) + for key, value in kwargs.items(): setattr(self, key.lower(), value) + if self.hash_ngrams_weights is None: self.hash_ngrams_weights= [1.0 for _ in range(self.hash_ngrams)] def transform_single(self, text): dft= self.dictionary.dft @@ -93,11 +88,12 @@ class WordBag: cdef int fc_hash_ngrams= self.hash_ngrams, word_id, df= 1, df2, hashed, doc_count= self.dictionary.doc_count, \ use_idf= 0, seed= self.seed cdef float idf_lift= 0.0, idf= 1.0, weight, norm= 1.0, norm_idf= 1.0 - if self.idf!= None: + if self.idf is not None: use_idf= True idf_lift= self.idf norm_idf= 1.0 / log(max(1.0, idf_lift + doc_count)) cdef int fc_hash_size= self.hash_size + if self.hash_ngrams == 0: hash_size = self.dictionary.max_words fc_hash_ngrams_weights= self.hash_ngrams_weights fc_tf= self.tf fc_norm= self.norm @@ -108,7 +104,7 @@ class WordBag: cdef TextRow textrow= TextRow() for x in range(len(text)): word= text[x] - if len(word2id)!=0: + if word2id is not None: word_id = 
word2id.get(word, -1) if word_id == -1 or word_id>= fc_hash_size: continue df= dft.get(word, 0) @@ -165,19 +161,18 @@ class WordBag: elif fc_norm== 'l1': norm= np.sum(np.abs(data_view)) elif fc_norm== 'l2': norm= np.sqrt(np.sum([w*w for w in data_view])) if norm != 0.0: norm= 1.0 / norm - if fc_norm!=None: wordbag.data*= norm + if fc_norm is not None: wordbag.data*= norm return wordbag - def batch_transform(self, texts): + def transform(self, texts, y= None):#input_split= False, merge_output= True, batcher= None): return ssp.vstack([self.transform_single(text) for text in texts]) - def transform(self, texts, input_split= False, merge_output= True): - return self.batcher.parallelize_batches(batch_transform, texts, [self], input_split= input_split, - merge_output= merge_output, procs= int(self.batcher.procs / 2)) - - def fit(self): + def fit(self, texts, y= None): return self + def fit_transform(self, texts, y=None): + return self.transform(texts, y=None) + def save_features(self, file, features): csr_to_lz4(file, features) @@ -185,20 +180,19 @@ class WordBag: return lz4_to_csr(file) class WordHash: - def __init__(self, batcher, dictionary, fea_cfg): - self.batcher= batcher - self.dictionary= dictionary - self.hv= HashingVectorizer(**fea_cfg) + def __init__(self, *args, **kwargs): + if "dictionary" in kwargs: kwargs.pop("dictionary") + self.hv = HashingVectorizer(*args, **kwargs) - def batch_transform(self, texts): return self.hv.transform(texts) + def transform(self, texts, y=None): + return self.hv.transform(texts) - def transform(self, texts, input_split= False, merge_output= True): - return self.batcher.parallelize_batches(batch_transform, texts, [self], input_split= input_split, - merge_output= merge_output, procs= int(self.batcher.procs / 2)) - - def fit(self): + def fit(self, texts, y=None): return self + def fit_transform(self, texts, y=None): + return self.transform(texts, y=None) + def save_features(self, file, features): csr_to_lz4(file, features) @@ -207,22 +201,21 @@ class WordHash: class WordSeq: - def __init__(self, batcher, dictionary, fea_cfg): - self.batcher= batcher - self.dictionary= dictionary - fea_cfg.setdefault("seq_maxlen", None) - fea_cfg.setdefault("seq_padstart", True) - fea_cfg.setdefault("seq_truncstart", True) - fea_cfg.setdefault("remove_oovs", False) - fea_cfg.setdefault("pad_id", 0) - fea_cfg.setdefault("oov_id", dictionary.max_words+1) - for key, value in fea_cfg.items(): setattr(self, key.lower(), value) + def __init__(self, *args, **kwargs): + self.dictionary = kwargs.get('dictionary', None) + kwargs.setdefault("seq_maxlen", None) + kwargs.setdefault("seq_padstart", True) + kwargs.setdefault("seq_truncstart", True) + kwargs.setdefault("remove_oovs", False) + kwargs.setdefault("pad_id", 0) + for key, value in kwargs.items(): setattr(self, key.lower(), value) def transform_single(self, text): word2id= self.dictionary.word2id + oov_id= self.dictionary.max_words+1 if self.remove_oovs: wordseq= [word2id[word] for word in text.split(" ") if word in word2id] - else: wordseq= [word2id.get(word, self.oov_id) for word in text.split(" ")] - if self.seq_maxlen != None: + else: wordseq= [word2id.get(word, oov_id) for word in text.split(" ")] + if self.seq_maxlen is not None: if len(wordseq) > self.seq_maxlen: if self.seq_truncstart: wordseq= wordseq[-self.seq_maxlen:] else: wordseq= wordseq[:self.seq_maxlen] @@ -231,15 +224,15 @@ class WordSeq: else: wordseq+= [self.pad_id] * (self.seq_maxlen - len(wordseq)) return wordseq - def batch_transform(self, texts): return 
[self.transform_single(text) for text in texts] - - def transform(self, texts, input_split= False, merge_output= True): - return self.batcher.parallelize_batches(batch_transform, texts, [self], input_split=input_split, - merge_output=merge_output, procs= int(self.batcher.procs / 2)) + def transform(self, texts, y= None): + return [self.transform_single(text) for text in texts] - def fit(self): + def fit(self, texts, y=None): return self + def fit_transform(self, texts, y=None): + return self.transform(texts, y=None) + def save_features(self, file, features): save_to_lz4(file, features, dtype=int) i= 0 @@ -256,24 +249,22 @@ class WordSeq: class WordVec: - def __init__(self, batcher, dictionary, fea_cfg): - self.batcher= batcher - self.dictionary= dictionary - fea_cfg.setdefault("normalize_text", None) - fea_cfg.setdefault("stemmer", None) - fea_cfg.setdefault("merge_dict", True) - fea_cfg.setdefault("normalize_dict", False) - fea_cfg.setdefault("verbose", 0) - fea_cfg.setdefault("merge_vectors", "mean") - fea_cfg.setdefault("normalize_merged", "l2") - fea_cfg.setdefault("encoding", "utf8") - fea_cfg.setdefault("shrink_model_transform", True) - fea_cfg.setdefault("w2v_dim", None) - for key, value in fea_cfg.items(): setattr(self, key.lower(), value) - if "w2v_model" in fea_cfg: self.w2v= fea_cfg["w2v_model"] - else: self.w2v= self.load_w2v(fea_cfg["wordvec_file"], fea_cfg['encoding'], fea_cfg['w2v_dim']) + def __init__(self, *args, **kwargs): + self.dictionary = kwargs.get('dictionary', None) + kwargs.setdefault("normalize_text", None) + kwargs.setdefault("stemmer", None) + kwargs.setdefault("merge_dict", True) + kwargs.setdefault("normalize_dict", False) + kwargs.setdefault("verbose", 0) + kwargs.setdefault("merge_vectors", "mean") + kwargs.setdefault("normalize_merged", "l2") + kwargs.setdefault("encoding", "utf8") + kwargs.setdefault("shrink_model_transform", True) + kwargs.setdefault("w2v_dim", None) + for key, value in kwargs.items(): setattr(self, key.lower(), value) + if "w2v_model" in kwargs: self.w2v= kwargs["w2v_model"] + else: self.w2v= self.load_w2v(kwargs["wordvec_file"], kwargs['encoding'], kwargs['w2v_dim']) self.w2v_dim= len(list(self.w2v.values())[0]) - self.fea_cfg= fea_cfg def load_w2v(self, w2v_file, encoding= "ISO-8859-1", w2v_dim= None): w2v= {} @@ -288,8 +279,8 @@ class WordVec: print("Wrong vector length", len(vec),", should be:", w2v_dim, ":", line) continue word= line[0] - if self.normalize_text!=None: word= self.normalize_text(word) - if self.stemmer!=None: word= self.stemmer.stem(word) + if self.normalize_text is not None: word= self.normalize_text(word) + if self.stemmer is not None: word= self.stemmer.stem(word) if not(self.merge_dict): w2v[word]= vec else: w2v_counts[word] += 1 @@ -316,10 +307,10 @@ class WordVec: for word in text: if word in w2v: vecs.append(w2v[word]) else: vecs.append(np.zeros(self.w2v_dim)) - if self.merge_vectors!=None: #Merge word vectors to a per-document vector + if self.merge_vectors is not None: #Merge word vectors to a per-document vector if self.merge_vectors=="mean": #Currently only mean vector suppported, could do max, median, etc. 
vec= np.mean(vecs, axis=0) - if self.normalize_merged!=None: #l1 and l2 normalization supported + if self.normalize_merged is not None: #l1 and l2 normalization supported if self.normalize_merged == "l1": norm = sum(np.abs(vec)) else: @@ -329,63 +320,56 @@ class WordVec: return vec return vecs - def batch_transform(self, texts): + def transform(self, texts, y=None): + #if batcher is None: return batch_transform(texts) + # if self.shrink_model_transform == True: + # #Send only word vectors occurring in texts to parallel processes. + # #Use to reduce memory footprint with big embedding files. + # d= wordbatch.transformers.dictionary.Dictionary(verbose=0, encode=False).fit(texts, input_split=input_split) + # w2v_model2= {x:self.w2v[x] for x in [z for z in self.w2v.keys() if z in d.dft]} + # fea_cfg2= self.fea_cfg + # fea_cfg2['w2v_model']= w2v_model2 + # self_shrunk= WordVec(dictionary=None, fea_cfg=fea_cfg2) + # else: self_shrunk= self + #return batcher.process_batches(batch_transform, texts, [self], input_split=input_split, + # merge_output=merge_output) return [self.transform_single(text) for text in texts] - def transform(self, texts, input_split= False, merge_output= True): - if self.fea_cfg['shrink_model_transform']== True: - #Send only word vectors occurring in texts to parallel processes. - #Use to reduce memory footprint with big embedding files. - d= wordbatch.transformers.dictionary.Dictionary(batcher=self.batcher, verbose=0, encode=False)\ - .fit(texts, input_split=input_split) - w2v_model2= {x:self.w2v[x] for x in [z for z in self.w2v.keys() if z in d.dft]} - fea_cfg2= self.fea_cfg - fea_cfg2['w2v_model']= w2v_model2 - self_shrunk= WordVec(batcher=self.batcher, dictionary=None, fea_cfg=fea_cfg2) - else: self_shrunk= self - return self.batcher.parallelize_batches(batch_transform, texts, [self_shrunk], input_split=input_split, - merge_output=merge_output, procs= int(self.batcher.procs / 2)) - - def fit(self): + def fit(self, texts, y=None): return self + def fit_transform(self, texts, y=None): + return self.transform(texts, y) -class Hstack: - def __init__(self, batcher, dictionary, fea_cfg): - self.batcher= batcher - self.dictionary= dictionary - t = [x[0](batcher, dictionary, x[1]) for x in fea_cfg] - self.extractors= list(t) - - def transform_single(self, text): - return sp.hstack([x.transform_single(text) for x in self.extractors]) - def batch_transform(self, texts): return [self.transform_single(text) for text in texts] +class Hstack: + def __init__(self, extractors): + self.extractors= extractors - def transform(self, texts, input_split= False, merge_output= True): - return self.batcher.parallelize_batches(batch_transform, texts, [self], input_split=input_split, - merge_output=merge_output, procs= int(self.batcher.procs / 2)) + def transform(self, texts, y= None): + return sp.hstack([x.transform(texts) for x in self.extractors]) - def fit(self): + def fit(self, texts, y=None): return self + def fit_transform(self, texts, y=None): + return self.transform(texts, y) + class PandasHash: - def __init__(self, batcher, dictionary, fea_cfg): - self.batcher= batcher - self.dictionary= dictionary + def __init__(self, *args, **kwargs): self.col_salt= None self.col_weight= None self.col_pick= [] #self.col_type= [] - for key, value in fea_cfg.items(): setattr(self, key.lower(), value) + for key, value in kwargs.items(): setattr(self, key.lower(), value) if self.col_salt is None: self.col_salt = ["".join([z[0] for z in x.replace(" ", "_").replace("|", "_").split("_")]) - for x in 
self.col_pick] + for x in self.col_pick] if self.col_weight is None: self.col_weight = np.ones(len(self.col_pick)) #if self.col_type is None: self.col_type = ["cat"]*len(self.col_pick) - def batch_transform(self, df): + def transform(self, df, y= None): D= self.n_features col_pick= self.col_pick col_salt= self.col_salt @@ -399,10 +383,9 @@ class PandasHash: datalist= [col_weight] * len(df), shape= (len(df), D)) - def transform(self, texts, input_split= False, merge_output= True): - return self.batcher.parallelize_batches(batch_transform, texts, [self], input_split=input_split, - merge_output=merge_output, procs= int(self.batcher.procs / 2)) - - def fit(self): + def fit(self, texts, y=None): return self + def fit_transform(self, texts, y=None): + return self.transform(texts, y) + diff --git a/wordbatch/feature_union.py b/wordbatch/feature_union.py deleted file mode 100644 index 73a3300..0000000 --- a/wordbatch/feature_union.py +++ /dev/null @@ -1,249 +0,0 @@ -from multiprocessing.pool import Pool - -import numpy as np -from scipy import sparse -from sklearn.base import TransformerMixin -from sklearn.pipeline import FeatureUnion, _fit_one_transformer, _fit_transform_one, _transform_one, _name_estimators -from sklearn.utils.metaestimators import _BaseComposition - -class FeatureUnion(_BaseComposition, TransformerMixin): - """Concatenates results of multiple transformer objects. - - This estimator applies a list of transformer objects in parallel to the - input data, then concatenates the results. This is useful to combine - several feature extraction mechanisms into a single transformer. - - Parameters of the transformers may be set using its name and the parameter - name separated by a '__'. A transformer may be replaced entirely by - setting the parameter with its name to another transformer, - or removed by setting to ``None``. - - Read more in the :ref:`User Guide `. - - Parameters - ---------- - transformer_list : list of (string, transformer) tuples - List of transformer objects to be applied to the data. The first - half of each tuple is the name of the transformer. - - n_jobs : int, optional - Number of jobs to run in parallel (default 1). - - transformer_weights : dict, optional - Multiplicative weights for features per transformer. - Keys are transformer names, values the weights. - - """ - - def __init__(self, transformer_list, n_jobs=1, transformer_weights=None, concatenate= True): - self.transformer_list = transformer_list - self.n_jobs = n_jobs - self.transformer_weights = transformer_weights - self._validate_transformers() - self.concatenate= concatenate - - def get_params(self, deep=True): - """Get parameters for this estimator. - - Parameters - ---------- - deep : boolean, optional - If True, will return the parameters for this estimator and - contained subobjects that are estimators. - - Returns - ------- - params : mapping of string to any - Parameter names mapped to their values. - """ - return self._get_params('transformer_list', deep=deep) - - def set_params(self, **kwargs): - """Set the parameters of this estimator. - - Valid parameter keys can be listed with ``get_params()``. 
- - Returns - ------- - self - """ - self._set_params('transformer_list', **kwargs) - return self - - def _validate_transformers(self): - names, transformers = zip(*self.transformer_list) - - # validate names - self._validate_names(names) - - # validate estimators - for t in transformers: - if t is None: - continue - if (not (hasattr(t, "fit") or hasattr(t, "fit_transform")) or not - hasattr(t, "transform")): - raise TypeError("All estimators should implement fit and " - "transform. '%s' (type %s) doesn't" % - (t, type(t))) - - def _iter(self): - """Generate (name, est, weight) tuples excluding None transformers - """ - get_weight = (self.transformer_weights or {}).get - return ((name, trans, get_weight(name)) - for name, trans in self.transformer_list - if trans is not None) - - def get_feature_names(self): - """Get feature names from all transformers. - - Returns - ------- - feature_names : list of strings - Names of the features produced by transform. - """ - feature_names = [] - for name, trans, weight in self._iter(): - if not hasattr(trans, 'get_feature_names'): - raise AttributeError("Transformer %s (type %s) does not " - "provide get_feature_names." - % (str(name), type(trans).__name__)) - feature_names.extend([name + "__" + f for f in - trans.get_feature_names()]) - return feature_names - - def fit(self, X, y=None): - """Fit all transformers using X. - - Parameters - ---------- - X : iterable or array-like, depending on transformers - Input data, used to fit transformers. - - y : array-like, shape (n_samples, ...), optional - Targets for supervised learning. - - Returns - ------- - self : FeatureUnion - This estimator - """ - self.transformer_list = list(self.transformer_list) - self._validate_transformers() - with Pool(self.n_jobs) as pool: - transformers = pool.starmap(_fit_one_transformer, - ((trans, X[trans['col_pick']] if hasattr(trans, 'col_pick') else X, y) for _, trans, _ in self._iter())) - self._update_transformer_list(transformers) - return self - - def fit_transform(self, X, y=None, **fit_params): - """Fit all transformers, transform the data and concatenate results. - - Parameters - ---------- - X : iterable or array-like, depending on transformers - Input data to be transformed. - - y : array-like, shape (n_samples, ...), optional - Targets for supervised learning. - - Returns - ------- - X_t : array-like or sparse matrix, shape (n_samples, sum_n_components) - hstack of results of transformers. sum_n_components is the - sum of n_components (output dimension) over transformers. - """ - self._validate_transformers() - with Pool(self.n_jobs) as pool: - result = pool.starmap(_fit_transform_one, - ((trans, weight, X[trans['col_pick']] if hasattr(trans, 'col_pick') else X, y) - for name, trans, weight in self._iter())) - if not result: - # All transformers are None - return np.zeros((X.shape[0], 0)) - Xs, transformers = zip(*result) - self._update_transformer_list(transformers) - if self.concatenate: - if any(sparse.issparse(f) for f in Xs): - Xs = sparse.hstack(Xs).tocsr() - else: - Xs = np.hstack(Xs) - return Xs - - def transform(self, X): - """Transform X separately by each transformer, concatenate results. - - Parameters - ---------- - X : iterable or array-like, depending on transformers - Input data to be transformed. - - Returns - ------- - X_t : array-like or sparse matrix, shape (n_samples, sum_n_components) - hstack of results of transformers. sum_n_components is the - sum of n_components (output dimension) over transformers. 
- """ - with Pool(self.n_jobs) as pool: - Xs = pool.starmap(_transform_one, ((trans, weight, X[trans['col_pick']] if hasattr(trans, 'col_pick') - else X) for name, trans, weight in self._iter())) - if not Xs: - # All transformers are None - return np.zeros((X.shape[0], 0)) - if self.concatenate: - if any(sparse.issparse(f) for f in Xs): - Xs = sparse.hstack(Xs).tocsr() - else: - Xs = np.hstack(Xs) - return Xs - - def _update_transformer_list(self, transformers): - transformers = iter(transformers) - self.transformer_list[:] = [ - (name, None if old is None else next(transformers)) - for name, old in self.transformer_list - ] - - -def make_union(*transformers, **kwargs): - """Construct a FeatureUnion from the given transformers. - - This is a shorthand for the FeatureUnion constructor; it does not require, - and does not permit, naming the transformers. Instead, they will be given - names automatically based on their types. It also does not allow weighting. - - Parameters - ---------- - *transformers : list of estimators - - n_jobs : int, optional - Number of jobs to run in parallel (default 1). - - Returns - ------- - f : FeatureUnion - - Examples - -------- - >>> from sklearn.decomposition import PCA, TruncatedSVD - >>> from sklearn.pipeline import make_union - >>> make_union(PCA(), TruncatedSVD()) # doctest: +NORMALIZE_WHITESPACE - FeatureUnion(n_jobs=1, - transformer_list=[('pca', - PCA(copy=True, iterated_power='auto', - n_components=None, random_state=None, - svd_solver='auto', tol=0.0, whiten=False)), - ('truncatedsvd', - TruncatedSVD(algorithm='randomized', - n_components=2, n_iter=5, - random_state=None, tol=0.0))], - transformer_weights=None) - """ - n_jobs = kwargs.pop('n_jobs', 1) - concatenate = kwargs.pop('concatenate', True) - if kwargs: - # We do not currently support `transformer_weights` as we may want to - # change its type spec in make_union - raise TypeError('Unknown keyword arguments: "{}"' - .format(list(kwargs.keys())[0])) - return FeatureUnion(_name_estimators(transformers), n_jobs= n_jobs, concatenate= concatenate) diff --git a/wordbatch/pipelines/__init__.py b/wordbatch/pipelines/__init__.py new file mode 100644 index 0000000..1de4174 --- /dev/null +++ b/wordbatch/pipelines/__init__.py @@ -0,0 +1,5 @@ +from .apply import Apply +from .apply_batch import ApplyBatch +from .batch_transformer import BatchTransformer +from .feature_union import FeatureUnion +from .wordbatch import WordBatch \ No newline at end of file diff --git a/wordbatch/pipelines/apply.py b/wordbatch/pipelines/apply.py new file mode 100644 index 0000000..055272c --- /dev/null +++ b/wordbatch/pipelines/apply.py @@ -0,0 +1,41 @@ +#!python +from __future__ import with_statement +from __future__ import division +from __future__ import absolute_import +from __future__ import print_function +import pandas as pd +import wordbatch.batcher + +def batch_transform(args): + f= args[1] + f_args= args[2] + f_kwargs= args[3] + #Applying per DataFrame row is very slow, use ApplyBatch instead + if isinstance(args[0], pd.DataFrame): return args[0].apply(lambda x: f(x, *f_args, **f_kwargs), axis=1) + return [f(row, *f_args, **f_kwargs) for row in args[0]] + +class Apply(object): + #Applies a function to each row of a minibatch + def __init__(self, function, batcher=None, args=[], kwargs={}): + if batcher is None: self.batcher= wordbatch.batcher.Batcher() + else: self.batcher= batcher + self.function= function + self.args= [args] + self.kwargs= [kwargs] + + def fit(self, data, input_split= False, batcher= None): + 
return self + + def fit_transform(self, data, input_split= False, merge_output= True, minibatch_size= None, batcher= None): + return self.transform(data, input_split, merge_output, minibatch_size, batcher) + + def transform(self, data, input_split= False, merge_output= True, minibatch_size= None, batcher= None): + if batcher is None: batcher = self.batcher + return batcher.process_batches(batch_transform, data, [self.function] + self.args + self.kwargs, + input_split=input_split, merge_output=merge_output, + minibatch_size= minibatch_size) +# import wordbatch.batcher as batcher +# b= batcher.Batcher(minibatch_size=2)#, method="serial") +# import numpy as np +# a= Apply(np.power, b, [2],{}) +# print(a.transform([1, 2, 3, 4])) \ No newline at end of file diff --git a/wordbatch/pipelines/apply_batch.py b/wordbatch/pipelines/apply_batch.py new file mode 100644 index 0000000..70026da --- /dev/null +++ b/wordbatch/pipelines/apply_batch.py @@ -0,0 +1,41 @@ +#!python +from __future__ import with_statement +from __future__ import division +from __future__ import absolute_import +from __future__ import print_function +import wordbatch.batcher + +def batch_transform(args): + f= args[1] + f_args= args[2] + f_kwargs= args[3] + return f(args[0], *f_args, **f_kwargs) + +class ApplyBatch(object): + #Applies a function to the entire minibatch. Use this for example on Pandas dataframes, to avoid per-row overhead. + #Function needs to be applicable to the array/list of values! + #If not, modify/wrap the function to process a list, or use Apply + def __init__(self, function, batcher=None, args=[], kwargs={}): + if batcher is None: self.batcher = wordbatch.batcher.Batcher() + else: self.batcher = batcher + self.function= function + self.args= [args] + self.kwargs= [kwargs] + + def fit(self, data, input_split= False): + return self + + def fit_transform(self, data, input_split=False, merge_output=True, minibatch_size=None, batcher=None): + return self.transform(data, input_split, merge_output, minibatch_size, batcher) + + def transform(self, data, input_split=False, merge_output=True, minibatch_size=None, batcher=None): + if batcher is None: batcher = self.batcher + return batcher.process_batches(batch_transform, data, [self.function] + self.args + self.kwargs, + input_split=input_split, merge_output=merge_output, + minibatch_size= minibatch_size) + +# import wordbatch.batcher as batcher +# b= batcher.Batcher(minibatch_size=2)#, method="serial") +# import numpy as np +# a= ApplyBatch(np.power, b, [2],{}) +# print(a.transform([1, 2, 3, 4])) \ No newline at end of file diff --git a/wordbatch/pipelines/batch_transformer.py b/wordbatch/pipelines/batch_transformer.py new file mode 100644 index 0000000..d385127 --- /dev/null +++ b/wordbatch/pipelines/batch_transformer.py @@ -0,0 +1,31 @@ +#!python +from __future__ import with_statement +from __future__ import division +from __future__ import absolute_import +from __future__ import print_function +import wordbatch.batcher + +def batch_transform(args): + return args[1].transform(args[0]) + +class BatchTransformer(object): + def __init__(self, transformer, call_fit=False, batcher=None): + if batcher is None: self.batcher = wordbatch.batcher.Batcher() + else: self.batcher = batcher + self.transformer= transformer + self.call_fit= call_fit + + def fit(self, data, input_split=False): + if self.call_fit: + if input_split: self.transformer.fit(self.batcher.merge_batches(self.collect_batches(data))) + else: self.transformer.fit(data) + return self + + def fit_transform(self, 
data, input_split=False, merge_output=True, minibatch_size=None): + if self.call_fit: self.fit(data, input_split= input_split) + return self.transform(data, input_split, merge_output, minibatch_size= minibatch_size) + + def transform(self, data, input_split=False, merge_output=True, minibatch_size=None): + return self.batcher.process_batches(batch_transform, data, [self.transformer], + input_split=input_split, merge_output=merge_output, + minibatch_size= minibatch_size) diff --git a/wordbatch/pipelines/feature_union.py b/wordbatch/pipelines/feature_union.py new file mode 100644 index 0000000..c329aa9 --- /dev/null +++ b/wordbatch/pipelines/feature_union.py @@ -0,0 +1,287 @@ +from multiprocessing.pool import Pool + +import numpy as np +from scipy import sparse +from sklearn.base import TransformerMixin +from sklearn.pipeline import FeatureUnion, _fit_one, _fit_transform_one, _transform_one, _name_estimators +from sklearn.utils.metaestimators import _BaseComposition +from wordbatch.pipelines import Apply + +def fit_one(args): + X, y, transformer, fit_params = args + return transformer.fit(X, y, **fit_params) + +def transform_one(args): + X, transformer= args + return transformer.transform(X) + +def fit_transform_one(args): + X, y, transformer, fit_params = args + return transformer.fit_transform(X, y, **fit_params), transformer + +class FeatureUnion(_BaseComposition, TransformerMixin): + """Concatenates results of multiple transformer objects. + + This estimator applies a list of transformer objects in parallel to the + input data, then concatenates the results. This is useful to combine + several feature extraction mechanisms into a single transformer. + + Parameters of the transformers may be set using its name and the parameter + name separated by a '__'. A transformer may be replaced entirely by + setting the parameter with its name to another transformer, + or removed by setting to 'drop' or ``None``. + + Read more in the :ref:`User Guide `. + + Parameters + ---------- + transformer_list : list of (string, transformer) tuples + List of transformer objects to be applied to the data. The first + half of each tuple is the name of the transformer. + + n_jobs : int or None, optional (default=None) + Number of jobs to run in parallel. + ``None`` means 1 unless in a :obj:`joblib.parallel_backend` context. + ``-1`` means using all processors. See :term:`Glossary ` + for more details. + + transformer_weights : dict, optional + Multiplicative weights for features per transformer. + Keys are transformer names, values the weights. + + verbose : boolean, optional(default=False) + If True, the time elapsed while fitting each transformer will be + printed as it is completed. + + See also + -------- + sklearn.pipeline.make_union : convenience function for simplified + feature union construction. + + Examples + -------- + >>> from sklearn.pipeline import FeatureUnion + >>> from sklearn.decomposition import PCA, TruncatedSVD + >>> union = FeatureUnion([("pca", PCA(n_components=1)), + ... 
("svd", TruncatedSVD(n_components=2))]) + >>> X = [[0., 1., 3], [2., 2., 5]] + >>> union.fit_transform(X) + array([[ 1.5 , 3.0..., 0.8...], + [-1.5 , 5.7..., -0.4...]]) + """ + _required_parameters = ["transformer_list"] + + def __init__(self, transformer_list, transformer_weights=None, batcher=None, concatenate=True): + self.transformer_list = transformer_list + self.transformer_weights = transformer_weights + self._validate_transformers() + self.batcher = batcher + self.concatenate= concatenate + + def get_params(self, deep=True): + """Get parameters for this estimator. + + Parameters + ---------- + deep : boolean, optional + If True, will return the parameters for this estimator and + contained subobjects that are estimators. + + Returns + ------- + params : mapping of string to any + Parameter names mapped to their values. + """ + return self._get_params('transformer_list', deep=deep) + + def set_params(self, **kwargs): + """Set the parameters of this estimator. + + Valid parameter keys can be listed with ``get_params()``. + + Returns + ------- + self + """ + self._set_params('transformer_list', **kwargs) + return self + + def _validate_transformers(self): + names, transformers = zip(*self.transformer_list) + + # validate names + self._validate_names(names) + + # validate estimators + for t in transformers: + if t is None or t == 'drop': + continue + if (not (hasattr(t, "fit") or hasattr(t, "fit_transform")) or not + hasattr(t, "transform")): + raise TypeError("All estimators should implement fit and " + "transform. '%s' (type %s) doesn't" % + (t, type(t))) + + def _iter(self): + """ + Generate (name, trans, weight) tuples excluding None and + 'drop' transformers. + """ + get_weight = (self.transformer_weights or {}).get + return ((name, trans, get_weight(name)) + for name, trans in self.transformer_list + if trans is not None and trans != 'drop') + + def get_feature_names(self): + """Get feature names from all transformers. + + Returns + ------- + feature_names : list of strings + Names of the features produced by transform. + """ + feature_names = [] + for name, trans, weight in self._iter(): + if not hasattr(trans, 'get_feature_names'): + raise AttributeError("Transformer %s (type %s) does not " + "provide get_feature_names." + % (str(name), type(trans).__name__)) + feature_names.extend([name + "__" + f for f in + trans.get_feature_names()]) + return feature_names + + def fit(self, X, y=None, **fit_params): + """Fit all transformers using X. + + Parameters + ---------- + X : iterable or array-like, depending on transformers + Input data, used to fit transformers. + + y : array-like, shape (n_samples, ...), optional + Targets for supervised learning. + + Returns + ------- + self : FeatureUnion + This estimator + """ + self.transformer_list = list(self.transformer_list) + self._validate_transformers() + paral_params = [[X[t['col_pick']] if hasattr(t, 'col_pick') else X, y, t, fit_params] + for _, t, _ in self._iter()] + transformers= Apply(fit_one, self.batcher).transform(paral_params) + # with Pool(self.n_jobs) as pool: + # transformers = pool.starmap(_fit_one, + # ((trans, X[trans['col_pick']] if hasattr(trans, 'col_pick') else X, y) for _, trans, _ in self._iter())) + self._update_transformer_list(transformers) + return self + + def fit_transform(self, X, y=None, **fit_params): + """Fit all transformers, transform the data and concatenate results. + + Parameters + ---------- + X : iterable or array-like, depending on transformers + Input data to be transformed. 
+ + y : array-like, shape (n_samples, ...), optional + Targets for supervised learning. + + Returns + ------- + X_t : array-like or sparse matrix, shape (n_samples, sum_n_components) + hstack of results of transformers. sum_n_components is the + sum of n_components (output dimension) over transformers. + """ + self._validate_transformers() + paral_params = [[X[t['col_pick']] if hasattr(t, 'col_pick') else X, y, t, fit_params] + for _, t, _ in self._iter()] + result = Apply(fit_transform_one, self.batcher).transform(paral_params) + if not result: + # All transformers are None + return np.zeros((X.shape[0], 0)) + Xs, transformers = zip(*result) + self._update_transformer_list(transformers) + if self.concatenate: + if any(sparse.issparse(f) for f in Xs): + Xs = sparse.hstack(Xs).tocsr() + else: + Xs = np.hstack(Xs) + return Xs + + def transform(self, X): + """Transform X separately by each transformer, concatenate results. + + Parameters + ---------- + X : iterable or array-like, depending on transformers + Input data to be transformed. + + Returns + ------- + X_t : array-like or sparse matrix, shape (n_samples, sum_n_components) + hstack of results of transformers. sum_n_components is the + sum of n_components (output dimension) over transformers. + """ + paral_params = [[X[t['col_pick']] if hasattr(t, 'col_pick') else X, t] for _, t, _ in self._iter()] + Xs = Apply(transform_one, self.batcher).transform(paral_params) + if not Xs: + # All transformers are None + return np.zeros((X.shape[0], 0)) + if self.concatenate: + if any(sparse.issparse(f) for f in Xs): + Xs = sparse.hstack(Xs).tocsr() + else: + Xs = np.hstack(Xs) + return Xs + + def _update_transformer_list(self, transformers): + transformers = iter(transformers) + self.transformer_list[:] = [(name, None if old is None or old == 'drop' else next(transformers)) + for name, old in self.transformer_list + ] + + +def make_union(*transformers, **kwargs): + """Construct a FeatureUnion from the given transformers. + + This is a shorthand for the FeatureUnion constructor; it does not require, + and does not permit, naming the transformers. Instead, they will be given + names automatically based on their types. It also does not allow weighting. + + Parameters + ---------- + *transformers : list of estimators + + n_jobs : int, optional + Number of jobs to run in parallel (default 1). 
+ + Returns + ------- + f : FeatureUnion + + Examples + -------- + >>> from sklearn.decomposition import PCA, TruncatedSVD + >>> from sklearn.pipeline import make_union + >>> make_union(PCA(), TruncatedSVD()) # doctest: +NORMALIZE_WHITESPACE + FeatureUnion(n_jobs=1, + transformer_list=[('pca', + PCA(copy=True, iterated_power='auto', + n_components=None, random_state=None, + svd_solver='auto', tol=0.0, whiten=False)), + ('truncatedsvd', + TruncatedSVD(algorithm='randomized', + n_components=2, n_iter=5, + random_state=None, tol=0.0))], + transformer_weights=None) + """ + n_jobs = kwargs.pop('n_jobs', 1) + concatenate = kwargs.pop('concatenate', True) + if kwargs: + # We do not currently support `transformer_weights` as we may want to + # change its type spec in make_union + raise TypeError('Unknown keyword arguments: "{}"' + .format(list(kwargs.keys())[0])) + return FeatureUnion(_name_estimators(transformers), n_jobs= n_jobs, concatenate= concatenate) diff --git a/wordbatch/pipelines/wordbatch.py b/wordbatch/pipelines/wordbatch.py new file mode 100644 index 0000000..5a9f1ef --- /dev/null +++ b/wordbatch/pipelines/wordbatch.py @@ -0,0 +1,101 @@ +#!python +from __future__ import with_statement +from __future__ import division +from __future__ import absolute_import +from __future__ import print_function +import os +import wordbatch.batcher +import wordbatch.pipelines +import wordbatch.transformers + +class WordBatch(object): + def __init__(self, normalize_text="new", dictionary="new", tokenizer=None, extractor=None, + freeze=False, batcher=None, verbose=0): + self.verbose= verbose + if batcher is None: self.batcher= wordbatch.batcher.Batcher(verbose=verbose) + else: self.batcher= batcher + + if normalize_text is None: self.normalize_text= None + elif normalize_text== "new": self.normalize_text= wordbatch.transformers.text_normalizer.TextNormalizer() + elif callable(normalize_text): + self.normalize_text= wordbatch.pipelines.apply.Apply(normalize_text, batcher=batcher) + else: self.normalize_text= normalize_text + + if tokenizer is None: self.tokenizer= None + else: self.tokenizer= tokenizer + + if dictionary is None: self.dictionary= None + elif dictionary== "new": self.dictionary= wordbatch.transformers.dictionary.Dictionary() + else: self.dictionary= dictionary + + if extractor is None: self.extractor= None + else: self.extractor= wordbatch.pipelines.batch_transformer.BatchTransformer(extractor, batcher=batcher) + if self.extractor is not None: + if hasattr(self.extractor.transformer, "dictionary"): + self.extractor.transformer.dictionary = self.dictionary + self.freeze= freeze + + def reset(self): + if self.dictionary is not None: self.dictionary.reset() + return self + + def process(self, data, y=None, input_split=False, reset=True, update=True): + if reset: self.reset() + if self.freeze: update= False + + if self.normalize_text is not None: + if self.verbose > 0: print("Normalize text") + data= self.normalize_text.transform(data, input_split=input_split, merge_output=False, + batcher= self.batcher) + input_split= True + if self.tokenizer is not None: + if self.verbose > 0: print("Tokenize text") + if update: data= self.tokenizer.fit_transform(data, y=y, input_split=input_split, merge_output=False, + reset=reset, batcher=self.batcher) + else: data= self.tokenizer.transform(data, y=y, input_split=input_split, merge_output=False, + batcher=self.batcher) + input_split= True + if self.dictionary is not None: + if self.verbose > 0: print("Use dictionary") + if update: data= 
self.dictionary.fit_transform(data, y=y, input_split=input_split, merge_output=False, + reset=reset, batcher=self.batcher) + else: data= self.dictionary.transform(data, y=y, input_split=input_split, merge_output=False, + batcher=self.batcher) + if self.verbose> 2: print("len(self.dictionary.dft):", len(self.dictionary.dft)) + return data + + def fit(self, data, y=None, input_split=False, reset=True): + self.process(data, y, input_split, reset=reset, update= True) + return self + + def transform(self, data, y=None, cache_features=None, input_split=False, reset=False, update=False): + if cache_features is not None: + if self.extractor is not None and os.path.exists(cache_features) and \ + hasattr(self.extractor.transformer, "load_features"): + return self.extractor.transformer.load_features(cache_features) + if not(input_split): data= self.batcher.split_batches(data) + + data= self.process(data, y=y, input_split=True, reset=reset, update=update) + if self.extractor is not None: + if self.verbose > 0: print("Extract features") + data= self.extractor.transform(data, input_split=True, merge_output=True) + if cache_features is not None and hasattr(self.extractor.transformer, "load_features"): + self.extractor.transformer.save_features(cache_features, data) + return data + else: + return self.batcher.merge_batches(data, self.batcher.backend) + + def partial_fit(self, data, y=None, input_split=False): + return self.fit(data, y, input_split, reset=False) + + def fit_transform(self, data, y=None, cache_features=None, input_split=False, reset=True): + return self.transform(data, y, cache_features, input_split, reset, update=True) + + def partial_fit_transform(self, data, y=None, cache_features=None, input_split=False): + return self.transform(data, y, cache_features, input_split, reset=False, update=True) + + def __getstate__(self): + return dict((k, v) for (k, v) in self.__dict__.items()) + + def __setstate__(self, params): + for key in params: setattr(self, key, params[key]) diff --git a/wordbatch/transformers/__init__.py b/wordbatch/transformers/__init__.py index acd83c7..aa59256 100644 --- a/wordbatch/transformers/__init__.py +++ b/wordbatch/transformers/__init__.py @@ -1,4 +1,3 @@ -from .apply import Apply -from .apply_batch import ApplyBatch from .dictionary import Dictionary from .tokenizer import Tokenizer +from .text_normalizer import TextNormalizer diff --git a/wordbatch/transformers/apply.py b/wordbatch/transformers/apply.py deleted file mode 100644 index 5ddfd48..0000000 --- a/wordbatch/transformers/apply.py +++ /dev/null @@ -1,37 +0,0 @@ -#!python -from __future__ import with_statement -from __future__ import division -from __future__ import absolute_import -from __future__ import print_function -import pandas as pd - -def batch_apply(args): - f= args[1] - f_args= args[2] - f_kwargs= args[3] - #Applying per DataFrame row is very slow, use ApplyBatch instead - if isinstance(args[0], pd.DataFrame): return args[0].apply(lambda x: f(x, *f_args, **f_kwargs), axis=1) - return [f(row, *f_args, **f_kwargs) for row in args[0]] - -class Apply(object): - #Applies a function to each row of a minibatch - def __init__(self, batcher, function, args=[], kwargs={}): - self.batcher= batcher - self.function= function - self.args= [args] - self.kwargs= [kwargs] - - def fit(self, data, input_split= False): - return self - - def fit_transform(self, data, input_split= False, merge_output= True): - return self.transform(data, input_split, merge_output) - - def transform(self, data, input_split= False, 
merge_output= True):
-		return self.batcher.parallelize_batches(batch_apply, data, [self.function]+self.args+self.kwargs,
-		                                        input_split=input_split, merge_output=merge_output)
-# import wordbatch.batcher as batcher
-# b= batcher.Batcher(minibatch_size=2)#, method="serial")
-# import numpy as np
-# a= Apply(b, np.power, [2],{})
-# print(a.transform([1, 2, 3, 4]))
\ No newline at end of file
diff --git a/wordbatch/transformers/apply_batch.py b/wordbatch/transformers/apply_batch.py
deleted file mode 100644
index 936e0b9..0000000
--- a/wordbatch/transformers/apply_batch.py
+++ /dev/null
@@ -1,37 +0,0 @@
-#!python
-from __future__ import with_statement
-from __future__ import division
-from __future__ import absolute_import
-from __future__ import print_function
-
-def batch_apply(args):
-	f= args[1]
-	f_args= args[2]
-	f_kwargs= args[3]
-	return f(args[0], *f_args, **f_kwargs)
-
-class ApplyBatch(object):
-	#Applies a function to the entire minibatch. Use this for example on Pandas dataframes, to avoid per-row overhead.
-	#Function needs to be applicable to the array/list of values!
-	#If not, modify/wrap the function to process a list, or use Apply
-	def __init__(self, batcher, function, args=[], kwargs={}):
-		self.batcher= batcher
-		self.function= function
-		self.args= [args]
-		self.kwargs= [kwargs]
-
-	def fit(self, data, input_split= False):
-		return self
-
-	def fit_transform(self, data, input_split= False, merge_output= True):
-		return self.transform(data, input_split, merge_output)
-
-	def transform(self, data, input_split= False, merge_output= True):
-		return self.batcher.parallelize_batches(batch_apply, data, [self.function]+self.args+self.kwargs,
-		                                        input_split=input_split, merge_output=merge_output)
-
-# import wordbatch.batcher as batcher
-# b= batcher.Batcher(minibatch_size=2)#, method="serial")
-# import numpy as np
-# a= ApplyBatch(b, np.power, [2],{})
-# print(a.transform([1, 2, 3, 4]))
\ No newline at end of file
diff --git a/wordbatch/transformers/dictionary.py b/wordbatch/transformers/dictionary.py
index ab4a623..ef9928a 100644
--- a/wordbatch/transformers/dictionary.py
+++ b/wordbatch/transformers/dictionary.py
@@ -3,34 +3,29 @@ from __future__ import division
 from __future__ import absolute_import
 from __future__ import print_function
-from collections import Counter
+from collections import defaultdict
 import operator
 
-WB_DOC_CNT= u'###DOC_CNT###' #Used for Spark document counting across RDFs
-
 def batch_get_dfs(args):
-	dft= Counter()
+	dft= defaultdict(int)
 	for text in args[0]:
 		for word in set(text.split(" ")):  dft[word]+= 1
-	dft[WB_DOC_CNT]+= len(args[0])  #Avoid Spark collect() by counting here
-	return dft
+	return [dict(dft), len(args[0])]
 
 class Dictionary(object):
-	def __init__(self, batcher, min_df=0, max_df=1.0, max_words= 10000000000000, freeze= False, encode=True, verbose=1):
+	def __init__(self, min_df=0, max_df=1.0, max_words= 10000000000000, freeze= False, encode=True, verbose=0):
 		self.verbose = verbose
 		self.freeze = freeze
 		self.max_words = max_words
 		self.min_df = min_df
 		self.max_df = max_df
-		self.batcher= batcher
 		self.encode= encode
 		self.word2id= None
 		self.reset()
-
 	def reset(self):
 		if self.encode: self.word2id = {}
-		self.dft = Counter()
+		self.dft = {}
 		self.doc_count = 0
 		return self
 
@@ -45,38 +40,47 @@ def get_pruning_dft(self, dft):
 
 	def prune_dictionary(self, max_words=None, min_df=None, max_df=None, re_encode= False, prune_dfs= True,
 	                     set_max_words= True):
 		#Prune dictionary. Optionally prune document frequency table as well
-		if max_words!=None: self.max_words= max_words
-		if min_df!=None: self.min_df= min_df
-		if max_df!= None: self.max_df= max_df
+		if max_words != None: self.max_words= max_words
+		if min_df != None: self.min_df= min_df
+		if max_df != None: self.max_df= max_df
 		max_words= self.max_words
 		word2id = self.word2id
 		dft = self.dft
 		sorted_dft, min_df2, max_df2 = self.get_pruning_dft(dft)
 		c= 0
-		print(len(sorted_dft), len(self.word2id), len(self.raw_dft))
+		#print(len(sorted_dft), len(self.word2id), len(self.raw_dft))
 		for word, df in sorted_dft:
-			if word not in word2id:
-				if re_encode: word2id[word]= -1
-				else: continue
+			if word2id is not None:
+				if word not in word2id:
+					if re_encode: word2id[word]= -1
+					else: continue
 			c+= 1
 			if c > max_words or df < min_df2 or df > max_df2:
 				if prune_dfs: dft.pop(word)
-				word2id.pop(word)
+				if word2id is not None: word2id.pop(word)
 			elif re_encode:
-				word2id[word]= c
-		if set_max_words: self.max_words= len(word2id)
+				if word2id is not None: word2id[word]= c
+		if set_max_words and word2id is not None: self.max_words= len(word2id)
 
-	def fit(self, data, input_split= False, reset= False):
+	def fit(self, data, y=None, input_split= False, reset= False, batcher= None):
 		if reset: self.reset()
-		dft= self.dft
+		if self.word2id is None:
+			self.word2id = {}
 		word2id= self.word2id
-		dfts= self.batcher.parallelize_batches(batch_get_dfs, data, [], input_split= input_split, merge_output=False)
-		if self.batcher.spark_context is not None: dfts= [batch[1] for batch in dfts.collect()]
-		self.doc_count+= sum([dft2.pop(WB_DOC_CNT) for dft2 in dfts])
-		for dft2 in dfts: dft.update(dft2)
-
-		#print(dft)
-		if word2id!=None:
+		if batcher is None: dfts, doc_counts= zip(*[batch_get_dfs(data)])
+		else:
+			# import wordbatch.pipelines
+			# dfts, doc_counts = zip(*batcher.collect_batches(
+			#	wordbatch.pipelines.apply_batch.ApplyBatch(get_dfs, batcher=batcher).transform(
+			#		data, input_split=input_split, merge_output=False)
+			# ))
+			dfts, doc_counts= zip(*batcher.collect_batches(
+				batcher.process_batches(batch_get_dfs, data, [], input_split= input_split, merge_output=False)))
+		self.doc_count += sum(doc_counts)
+		dft = defaultdict(int, self.dft)
+		for dft2 in dfts:
+			for k, v in dft2.items(): dft[k] += v
+		if word2id is not None:
 			#Add entries. Online pruning only used to prevent inclusion into dictionary
 			sorted_dft, min_df2, max_df2 = self.get_pruning_dft(dft)
 			for word, df in sorted_dft:
@@ -85,12 +89,19 @@ def fit(self, data, input_split= False, reset= False):
 				if word in word2id: continue
 				word2id[word] = len(word2id)+1
 				if self.verbose>2: print("Add word to dictionary:", word, dft[word], word2id[word])
+		self.dft= dict(dft)
 		return self
 
-	def fit_transform(self, data, input_split= False, merge_output= True, reset= False):
-		self.fit(data, input_split, reset)
-		return self.transform(data, input_split= input_split, merge_output= merge_output)
+	def partial_fit(self, data, y=None, input_split=False, batcher=None):
+		return self.fit(data, y, input_split, reset=False, batcher=batcher)
+
+	def fit_transform(self, data, y=None, input_split= False, merge_output= True, reset= True, batcher= None):
+		self.fit(data, y=y, input_split= input_split, reset=reset, batcher=batcher)
+		return self.transform(data, y=y, input_split= input_split, merge_output= merge_output, batcher= None)
+
+	def partial_fit_transform(self, data, y=None, cache_features=None, input_split=False, batcher=None):
+		return self.transform(data, y, cache_features, input_split, reset=False, update=True, batcher=batcher)
 
-	def transform(self, data, input_split= False, merge_output= True):
-		if input_split and merge_output: data= self.batcher.merge_batches(data)
+	def transform(self, data, y=None, input_split= False, merge_output= True, batcher= None):
+		if input_split and merge_output and batcher is not None: data= batcher.merge_batches(data)
 		return data
\ No newline at end of file
diff --git a/wordbatch/transformers/text_normalizer.py b/wordbatch/transformers/text_normalizer.py
new file mode 100644
index 0000000..605ea6a
--- /dev/null
+++ b/wordbatch/transformers/text_normalizer.py
@@ -0,0 +1,40 @@
+#!python
+from __future__ import with_statement
+from __future__ import division
+from __future__ import absolute_import
+from __future__ import print_function
+import re
+import wordbatch.batcher
+import wordbatch.pipelines
+import wordbatch.transformers
+
+def batch_transform(args):
+	return args[1].batch_transform(args[0])
+
+non_alphanums= re.compile(u'[^A-Za-z0-9]+')
+def default_normalize_text(text):
+	return u" ".join([x for x in [y for y in non_alphanums.sub(' ', text).lower().strip().split(" ")] if len(x)>1])
+
+class TextNormalizer(object):
+	def __init__(self, normalize_text= default_normalize_text, freeze= False, verbose= 1):
+		self.verbose= verbose
+		self.freeze= freeze
+		self.non_alphanums = re.compile(u'[^A-Za-z0-9]+')
+		self.normalize_text= normalize_text
+		self.reset()
+
+	def reset(self):
+		return self
+
+	def batch_transform(self, data):  return [self.normalize_text(text) for text in data]
+
+	def transform(self, data, input_split=False, merge_output=True, batcher=None):
+		if batcher is None:  batcher = wordbatch.batcher.Batcher()
+		return batcher.process_batches(batch_transform, data, [self], input_split=input_split,
+		                               merge_output=merge_output)
+
+	def fit(self, data, y=None, input_split=False, merge_output=True, batcher=None):
+		return self
+
+	def fit_transform(self, data, y=None, input_split=False, merge_output=True, batcher=None):
+		return self.transform(data, input_split, merge_output, batcher)
\ No newline at end of file
diff --git a/wordbatch/transformers/tokenizer.py b/wordbatch/transformers/tokenizer.py
index c79461a..faa8d3c 100644
--- a/wordbatch/transformers/tokenizer.py
+++ b/wordbatch/transformers/tokenizer.py
@@ -6,16 +6,12 @@
 #from nltk.metrics import edit_distance
 import Levenshtein #python-Levenshtein
 from collections import defaultdict
-from collections import Counter
-
-WB_DOC_CNT= u'###DOC_CNT###' #Used for Spark document counting across RDFs
 
 def batch_get_dfs(args):
-	dft= Counter()
+	dft= defaultdict(int)
 	for text in args[0]:
 		for word in set(text.split(" ")):  dft[word]+= 1
-	dft[WB_DOC_CNT]+= len(args[0])  #Avoid Spark collect() by counting here
-	return dft
+	return [dict(dft), len(args[0])]
 
 def correct_spelling(word, dft, spell_index, spellcor_count, spellcor_dist):
 	#T. Bocek, E. Hunt, B. Stiller: Fast Similarity Search in Large Dictionaries, 2007
@@ -59,8 +55,7 @@ def get_deletions(word, order):
 	return list(results.keys())
 
 class Tokenizer(object):
-	def __init__(self, batcher, spellcor_count=0, spellcor_dist=2, raw_min_df= 0, stemmer= None, freeze= False,
-	             verbose= 1):
+	def __init__(self, spellcor_count=0, spellcor_dist=2, raw_min_df= 0, stemmer= None, freeze= False, verbose= 0):
 		self.verbose= verbose
 		self.freeze= freeze
 		if spellcor_count == 0:
@@ -70,53 +65,59 @@ def __init__(self, batcher, spellcor_count=0, spellcor_dist=2, raw_min_df= 0, st
 		self.spellcor_count = spellcor_count
 		self.spellcor_dist = spellcor_dist
 		self.stemmer = stemmer
-		self.raw_min_df = raw_min_df
-		self.batcher= batcher
+		self.raw_min_df = raw_min_df #FIX
 		self.reset()
 
 	def reset(self):
-		self.dft = Counter()
+		self.dft = {}
 		self.doc_count = 0
 		return self
 
-	def fit(self, data, input_split= False, reset= True):
+	def fit(self, data, y= None, input_split= False, reset= True, batcher= None):
 		if reset: self.reset()
 		if self.freeze: return self
-		dft = self.dft
-		dfts = self.batcher.parallelize_batches(batch_get_dfs, data, [], input_split=input_split, merge_output=False)
-		if self.batcher.spark_context is not None: dfts = [batch[1] for batch in dfts.collect()]
-		self.doc_count += sum([dft2.pop(WB_DOC_CNT) for dft2 in dfts])
-		for dft2 in dfts: dft.update(dft2)
+		if batcher is None: dfts, doc_counts= zip(*[batch_get_dfs(data)])
+		else:
+			dfts, doc_counts= zip(*batcher.collect_batches(
+				batcher.process_batches(batch_get_dfs, data, [], input_split= input_split, merge_output=False)))
+		self.doc_count += sum(doc_counts)
+		dft = defaultdict(int, self.dft)
+		for dft2 in dfts:
+			for k, v in dft2.items(): dft[k] += v
+		self.dft= dict(dft)
 		return self
 
-	def fit_transform(self, data, input_split= False, merge_output= True, reset= True):
-		self.fit(data, input_split, reset)
-		return self.transform(data, input_split, merge_output)
+	def partial_fit(self, data, y=None, input_split=False, batcher=None):
+		return self.fit(data, y, input_split, reset=False, batcher=batcher)
+
+	def fit_transform(self, data, y=None, input_split= False, merge_output= True, reset= True, batcher= None):
+		self.fit(data, y=y, input_split= input_split, reset=reset, batcher=batcher)
+		return self.transform(data, y=y, input_split=input_split, merge_output=merge_output, batcher=batcher)
 
-	def transform(self, data, input_split= False, merge_output= True):
+	def partial_fit_transform(self, data, y=None, cache_features=None, input_split=False, batcher=None):
+		return self.transform(data, y, cache_features, input_split, reset=False, update=True, batcher=batcher)
+
+	def transform(self, X, y=None, input_split= False, merge_output= True, batcher= None):
 		if self.verbose > 0:  print("Make word normalization dictionary")
-		if self.spellcor_dist > 0:
-			dft2 = {word: self.dft[word] for word in self.dft if self.dft[word] > self.spellcor_count}
+		do_corrections= 1 if (self.spellcor_count > 0) and (self.spellcor_dist>0) else 0
+		if not(do_corrections) and self.stemmer is None:  return X
+		if do_corrections:
+			dft2 = {w[0]: w[1] for w in self.dft.items() if w[1] > self.spellcor_count}
 			spell_index = defaultdict(list)
 			for word in dft2:
 				if len(word) > 15:  continue
 				for word2 in get_deletions(word, self.spellcor_dist):
 					spell_index[word2].append(word)
-		if self.stemmer != None:
-			if self.spellcor_count > 0:
+		if self.stemmer is not None:
+			if do_corrections:
 				corrs = {word: self.stemmer.stem(correct_spelling(
 					word, dft2, spell_index, self.spellcor_count, self.spellcor_dist)) for word in self.dft}
 			else:
 				corrs = {word: self.stemmer.stem(word) for word in self.dft}
-		else:
+		elif do_corrections:
 			corrs = {word: correct_spelling(
 				word, dft2, spell_index, self.spellcor_count, self.spellcor_dist) for word in self.dft}
 		corrs = {key: value for key, value in corrs.items() if key != value}
 		if self.verbose > 0:  print("Make word normalizations")
-		return self.batcher.parallelize_batches(batch_correct_spellings, data, [corrs],
-		                                        input_split=input_split, merge_output=merge_output)
-# import wordbatch.batcher as batcher
-# b= batcher.Batcher(method="serial")
-# t= [[1, 2], [3, 4]]
-# import numpy as np
-# a= Apply(b, np.power, [2],{})
-# print(a.transform(t))
\ No newline at end of file
+		if batcher is None:  return batch_correct_spellings(X)
+		return batcher.process_batches(batch_correct_spellings, X, [corrs],
+		                               input_split=input_split, merge_output=merge_output)
diff --git a/wordbatch/wordbatch.py b/wordbatch/wordbatch.py
deleted file mode 100644
index 4362de8..0000000
--- a/wordbatch/wordbatch.py
+++ /dev/null
@@ -1,104 +0,0 @@
-#!python
-from __future__ import with_statement
-from __future__ import division
-from __future__ import absolute_import
-from __future__ import print_function
-import multiprocessing
-import re
-import os
-import wordbatch.batcher as batcher
-
-non_alphanums= re.compile(u'[^A-Za-z0-9]+')
-def default_normalize_text(text):
-	return u" ".join([x for x in [y for y in non_alphanums.sub(' ', text).lower().strip().split(" ")] if len(x)>1])
-
-class WordBatch(object):
-	def __init__(self, normalize_text= default_normalize_text, max_words= 10000000, min_df= 0, max_df= 1.0,
-	             spellcor_count=0, spellcor_dist=2, raw_min_df= -1, stemmer= None, extractor=None,
-	             procs=0, minibatch_size= 20000, timeout= 600, spark_context= None, freeze= False,
-	             method= "multiprocessing", verbose= 1):
-		if procs==0: procs= multiprocessing.cpu_count()
-		self.verbose= verbose
-		self.batcher= batcher.Batcher(procs=procs, minibatch_size=minibatch_size, timeout=timeout,
-		                              spark_context=spark_context, method=method, verbose=verbose)
-
-		import wordbatch.transformers.apply as apply
-		if normalize_text is None: self.normalize_text= None
-		else: self.normalize_text= apply.Apply(self.batcher, normalize_text)
-
-		import wordbatch.transformers.dictionary as dictionary
-		self.dictionary= dictionary.Dictionary(self.batcher, min_df=min_df, max_df=max_df, max_words= max_words,
-		                                       freeze= False, verbose=verbose)
-
-		import wordbatch.transformers.tokenizer as tokenizer
-		if spellcor_count>0 or stemmer!=None:
-			self.tokenizer= tokenizer.Tokenizer(self.batcher, spellcor_count, spellcor_dist, raw_min_df, stemmer,
-			                                    verbose= verbose)
-		else: self.tokenizer= None
-		self.set_extractor(extractor)
-		self.freeze= freeze
-
-	def reset(self):
-		self.dictionary.reset()
-		return self
-
-	def set_extractor(self, extractor=None):
-		if extractor is not None:
-			if type(extractor) != tuple and type(extractor) != list:
-				self.extractor = extractor(self.batcher, self.dictionary, {})
-			else: self.extractor = extractor[0](self.batcher, self.dictionary, extractor[1])
-		else: self.extractor = None
-
-	def process(self, texts, input_split= False, reset= True, update= True):
-		if reset: self.reset()
-		if self.freeze: update= False
-
-		if self.normalize_text is not None:
-			if self.verbose > 0: print("Normalize text")
-			texts= self.normalize_text.transform(texts, input_split= input_split, merge_output= False)
-			input_split= True
-
-		if self.tokenizer is not None:
-			if self.verbose > 0: print("Tokenize text")
-			if update: texts= self.tokenizer.fit_transform(texts, input_split= input_split, merge_output= False,
-			                                               reset= reset)
-			else: texts= self.tokenizer.transform(texts, input_split= input_split, merge_output= False)
-			input_split= True
-
-		if self.dictionary is not None:
-			if update:
-				texts= self.dictionary.fit_transform(texts, input_split=input_split, merge_output=False, reset=reset)
-				if self.verbose> 2: print("len(self.dictionary.dft):", len(self.dictionary.dft))
-		return texts
-
-	def fit(self, texts, input_split= False, reset= True):
-		self.process(texts, input_split, reset=reset, update= True)
-		return self
-
-	def transform(self, texts, extractor= None, cache_features= None, input_split= False, reset= False, update= False):
-		if extractor== None: extractor= self.extractor
-		if cache_features != None and os.path.exists(cache_features): return extractor.load_features(cache_features)
-		if not(input_split): texts= self.batcher.split_batches(texts)
-
-		texts= self.process(texts, input_split=True, reset=reset, update= update)
-		if extractor!= None:
-			texts= extractor.transform(texts, input_split= True, merge_output= True)
-			if cache_features!=None: extractor.save_features(cache_features, texts)
-			return texts
-		else:
-			return self.batcher.merge_batches(texts)
-
-	def partial_fit(self, texts, input_split=False):
-		return self.fit(texts, input_split, reset=False)
-
-	def fit_transform(self, texts, extractor=None, cache_features=None, input_split=False, reset=True):
-		return self.transform(texts, extractor, cache_features, input_split, reset, update=True)
-
-	def partial_fit_transform(self, texts, extractor=None, cache_features=None, input_split=False):
-		return self.transform(texts, extractor, cache_features, input_split, reset=False, update=True)
-
-	def __getstate__(self):
-		return dict((k, v) for (k, v) in self.__dict__.items())
-
-	def __setstate__(self, params):
-		for key in params: setattr(self, key, params[key])
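
The refactor above moves orchestration out of the transformer constructors: Dictionary, Tokenizer and
TextNormalizer no longer receive a Batcher when constructed, and instead accept one through their
fit()/transform() calls. The sketch below illustrates that calling convention. It is not part of the patch:
it assumes the refactored classes and the Batcher constructor shown in this changeset, and the chosen backend
and parameter values are arbitrary examples.

| # Illustrative sketch, not part of the patch: one Batcher shared across transformer calls.
| from wordbatch.batcher import Batcher
| from wordbatch.transformers.text_normalizer import TextNormalizer
| from wordbatch.transformers.tokenizer import Tokenizer
| from wordbatch.transformers.dictionary import Dictionary
|
| if __name__ == "__main__":  # guard recommended for process-based backends
|     texts = ["It's just a flesh wound.", "A flesh wound? Your arm's off!"]
|     batcher = Batcher(backend="multiprocessing")
|
|     # Normalize raw text in parallel minibatches
|     texts = TextNormalizer().transform(texts, batcher=batcher)
|
|     # No spelling correction or stemming configured, so transform() returns the
|     # input unchanged while fit() collects per-batch document frequencies
|     texts = Tokenizer(spellcor_count=0, stemmer=None).fit_transform(texts, batcher=batcher)
|
|     # Build the word -> id mapping from the pooled document frequencies
|     dictionary = Dictionary(min_df=1)
|     dictionary.fit(texts, batcher=batcher)
|     print(dictionary.word2id)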