
Speeding Up with Multithreading



When each item needs a request to a remote service and a wait for its response, threads can overlap that I/O and speed the whole job up.
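The full script below is built on one pattern: put every input line into a queue.Queue, start a fixed number of threading.Thread workers that each pull items until the queue is empty, then join the workers. Because each item mostly waits on a blocking network call (the vector-search request), the GIL is released during that wait, so several requests are in flight at once. Here is a minimal, self-contained sketch of that pattern; the fetch_one function, the lock around the shared result list, and the worker count of 8 are illustrative placeholders, not part of the original script.

import queue
import threading

def fetch_one(sample):
    # Stand-in for the real per-item work: a blocking network request,
    # e.g. the vector-search call made in the full script below.
    return sample.upper()

def worker(q, results, lock):
    while True:
        try:
            sample = q.get_nowait()  # non-blocking; queue.Empty means the work is done
        except queue.Empty:
            return
        res = fetch_one(sample)
        with lock:  # guard the shared result list across threads
            results.append(res)

q = queue.Queue()
for s in ['query a', 'query b', 'query c']:
    q.put(s)

results, lock = [], threading.Lock()
threads = [threading.Thread(target=worker, args=(q, results, lock)) for _ in range(8)]
for t in threads:
    t.start()
for t in threads:
    t.join()
print(results)

The full script uses the same structure, except that each worker writes its result straight to a shared output file instead of a list.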

"""
    Function: get similarity query
    Author: dengyx
    DateTime: 20201019
"""
import jieba
import time
import tqdm
import threading
import queue
import numpy as np
from gensim.models import KeyedVectors
import logging
logging.basicConfig(format='%(asctime)s : %(levelname)s : %(message)s', level=logging.INFO)
from utils.soaring_vector.soaring_vector.soaring_vector import SoaringVectorClient, IndexType, FieldType, IndexData, SearchQuery, Vector, FloatVector

client = SoaringVectorClient("172.16.24.150", 8098, 1000)
print("health : ", client.health())

index_name = "seo-query-v10dim"
if client.exist(index_name):
    api_index = client.get(index_name)
    print(index_name + " already exists")
else:
    schema = {'query': FieldType.STRING_TYPE, 'id': FieldType.STRING_TYPE}
    api_index = client.create(index_name, "query search match", IndexType.FloatFlatIP, 'query', schema, 10, thread=12)
    client.set_alias(index_name, "seo-phrase-match")

print(api_index.info)


class QuerySimilarity(object):
    def __init__(self,):
        # self.query_path = r'data/test.txt'
        self.query_path = r'data/seo_search_word_copy.txt'
        self.w2c_path = r'resources_10dim/word2vec.model'
        self.query_features = r'resources/features.pkl'
        self.tables = r'resources/hashtables.pkl'
        self.table_num = 3
        self.Hashcode_fun = 6
        self.query2id = {}
        self.thread_num = 8

        print('Loading word vectors...')
        t1 = time.time()
        self.model = KeyedVectors.load(self.w2c_path, mmap='r')
        t2 = time.time()
        print('Word vectors loaded in {:.2f}s'.format(t2 - t1))
        with open(self.query_path, 'r', encoding='utf8') as fr:
            self.content = fr.readlines()

        for each in self.content:
            item = each.strip().split('\t')
            query_id = item[0]
            query = item[-1]
            self.query2id[query] = query_id

    def cosine_sim(self, x, y):
        num = x.dot(y.T)
        denom = np.linalg.norm(x) * np.linalg.norm(y)
        return num / denom

    def feature_extract(self, query):
        """ word -> feature
        :param query:
        :return:
        """
        vec = []
        tokens = jieba.lcut(query)
        for word in tokens:
            if word in self.model:
                vec.append(self.model[word])
            else:
                vec.append([0]*10)
                # print('{}\n{}\n{} not in word2vec'.format(query, tokens, word))
        vec = np.array(vec)
        mean_vec = np.mean(vec, axis=0)
        if len(mean_vec) != 10:
            print('Vector dimension is not 10')
        return mean_vec

    def upload_data(self):
        """ clean segment stopwords
        :return:
        """
        self.counter = 0
        # self.query2id = {}
        data_map_buffer = dict()
        for each in self.content:
            item = each.strip().split('\t')
            query_id = item[0]
            query = item[-1]
            # self.query2id[query] = query_id
            current_feature = self.feature_extract(query)
            vector = self.l2_norm(current_feature).tolist()
            data = {'query': query, 'id': query_id}
            data_map_buffer[query] = IndexData(data, vector)
            if len(data_map_buffer) > 1000:
                api_index.put(data_map_buffer)
                self.counter += len(data_map_buffer)
                data_map_buffer = dict()
                logging.info('put ' + str(self.counter))
        if len(data_map_buffer) > 0:
            api_index.put(data_map_buffer)
            self.counter += len(data_map_buffer)
            logging.info('put ' + str(self.counter))
            data_map_buffer = dict()
        print('Data upload finished')

    def l2_norm(self, m):
        dist = np.sqrt((m ** 2).sum(-1))[..., np.newaxis]
        m /= dist
        return m

    def download(self):
        with open(self.query_path, 'r', encoding='utf8') as fr:
            content = fr.readlines()
            new_content = []
            for each in tqdm.tqdm(content):
                each_item = each.strip().split('\t')
                phrase = each_item[-1]

                api_vector = dict(api_index.get(phrase).data.vector.vector).get(phrase).floatVector.values
                query = SearchQuery(vector=Vector(floatVector=FloatVector(values=api_vector)))
                res = api_index.search(query, 0, 40)
                line = ''
                for ret in res.result:
                    items = sorted(ret.item, key=lambda v: v.score, reverse=True)
                    for item in items[1:31]:
                        line += self.query2id[item.key] + ' '  # the separator literal was lost in the paste; a single space is assumed
                to_save = each.strip() + '\t' + line[:-1] + '\n'
                new_content.append(to_save)

        save_path = r'data/query_top30_20201021.txt'
        with open(save_path, 'w', encoding='utf8') as fw:
            fw.writelines(new_content)
        print('Results saved to: {}'.format(save_path))

    def run(self, q, fw):
        while True:
            if q.empty():
                return
            else:
                sample = q.get()
                each_item = sample.strip().split('\t')
                phrase = each_item[-1]
                api_vector = dict(api_index.get(phrase).data.vector.vector).get(phrase).floatVector.values
                query = SearchQuery(vector=Vector(floatVector=FloatVector(values=api_vector)))
                res = api_index.search(query, 0, 40)
                line = ''
                # result = []
                for ret in res.result:
                    items = sorted(ret.item, key=lambda v: v.score, reverse=True)
                    for item in items[1:31]:
                        line += self.query2id[item.key] + ' '  # separator literal lost in the paste; a space is assumed
                        # result.append(item.key)
                to_save = sample.strip() + '\t' + line[:-1] + '\n'
                # print(result)
                # print(to_save)
                print(each_item[0])
                fw.write(to_save)

    def main(self, data_path):
        q = queue.Queue()
        save_path = r'data/query_top30_20201022.txt'
        fw = open(save_path, 'a', encoding='utf8')

        # split_num = 250000
        # with open(self.query_path, 'r', encoding='utf8') as fr:
        #     content = fr.readlines()
        #     for i in range(0, len(content), split_num):
        #         split_data = content[i:i+split_num]
        #         with open('data/split_data/group_{}.txt'.format(i), 'w', encoding='utf8') as fw:
        #             fw.writelines(split_data)

        with open(data_path, 'r', encoding='utf8') as fr:
            content = fr.readlines()
            for d in tqdm.tqdm(content):
                q.put(d)
            print('All samples enqueued')
        t1 = time.time()
        threads = []
        print('Processing queries...')
        for i in range(self.thread_num):
            t = threading.Thread(target=self.run, args=(q, fw))
            threads.append(t)
        for i in range(self.thread_num):
            threads[i].start()
        for i in range(self.thread_num):
            threads[i].join()
        t2 = time.time()
        print('Throughput: {:.4f} samples/s'.format(len(content) / (t2 - t1)))
        print('All results written')


# Press the green button in the gutter to run the script.
if __name__ == '__main__':
    data_path = r'data/seo_search_word_copy.txt'
    qs = QuerySimilarity()
    qs.main(data_path)
    # qs.upload_data()

 


Original article: https://www.cnblogs.com/demo-deng/p/13857560.html
