首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >加速Keras模型预测负荷时间

加速Keras模型预测负荷时间
EN

Data Science用户
提问于 2021-01-05 16:00:58
回答 2查看 1.3K关注 0票数 3

我正在尝试使用keras创建一个预测API,该API加载模型、预测和关闭模型。但是python中的初始化时间大约为3-5秒,因此每个请求都需要大约5秒才能返回预测,而不管输入的行数(预测)如何。

是否有任何方法保持模型加载,然后流输入数据,以获得预测。就像预装的模型,不管是通过套接字还是通过端口。

类似于开放的office文档转换器

代码语言:javascript
复制
\program\soffice.exe -accept="socket,host=127.0.0.1,port=8100;urp;" -headless -nofirststartwizard -nologo

Keras预测码

代码语言:javascript
复制
#!/usr/bin/env python3.6
import sys
import pandas as pd
from keras.models import load_model
model = load_model('model.h5')
X = pd.read_csv(sys.argv[1]).values
prediction = model.predict(X)
pd.DataFrame(prediction).to_json(sys.argv[2])

脚本被称为

代码语言:javascript
复制
python3.6 predict.py input_scaled.csv output_scaled.json

预测时间如下

代码语言:javascript
复制
#row    time
1       4.76 secs
10      4.49 secs
50      5.37 secs
5000    5.46 secs
50000   12.7 secs
EN

回答 2

Data Science用户

发布于 2021-01-06 10:38:35

不用烧瓶和django我就能像这样工作了。只是在python中使用默认的http.server

代码语言:javascript
复制
from http.server import BaseHTTPRequestHandler, HTTPServer
import logging
import sys
import pandas as pd
from keras.models import load_model
from urllib.parse import urlparse
model = load_model('model.h5')

class S(BaseHTTPRequestHandler):
    def _set_response(self):
        self.send_response(200)
        self.send_header('Content-type', 'text/html')
        self.end_headers()

    def do_GET(self):
        query = urlparse(self.path).query
        params = dict(qc.split("=") for qc in query.split("&"))
        X = pd.read_csv(params["input"]).values
        prediction = model.predict(X)
        pd.DataFrame(prediction).to_json(params["output"])
        self._set_response()
        self.wfile.write("Processed".encode('utf-8'))

def run(server_class=HTTPServer, handler_class=S, port=8080):
    logging.basicConfig(level=logging.INFO)
    server_address = ('', port)
    httpd = server_class(server_address, handler_class)
    logging.info('Starting httpd...\n')
    try:
        httpd.serve_forever()
    except KeyboardInterrupt:
        pass
    httpd.server_close()
    logging.info('Stopping httpd...\n')

if __name__ == '__main__':
    from sys import argv

    if len(argv) == 2:
        run(port=int(argv[1]))
    else:
        run()

触发器服务器使用

代码语言:javascript
复制
python3.6 predict_server.py 8000

类API

代码语言:javascript
复制
http://ip/localhost:8000/?input=predict_scaled.csv&output=prediction.json
票数 1
EN

Data Science用户

发布于 2021-01-06 09:20:59

我能想到的最简单的方法是创建一个烧瓶应用程序,它将加载模型一次,并且有一个端点,您可以将数据作为请求发送到已经加载的模型。

服务的大致框架如下所示:

代码语言:javascript
复制
from flask import Flask, request


app = Flask(__name__)


@app.route('/')
def index():
    return ''

@app.route('/predict/', methods=['GET', 'POST'])
def predict():
    X = pd.read_csv(request.get_data()).values
    prediction = model.predict(X)
    return pd.DataFrame(prediction).to_json()


if __name__ == "__main__":
    model = load_model('model.h5')
    app.run()

然后,您可以通过另一个脚本向localhost:5000/predict发出HTTP请求,它将返回您的预测,然后您可以保存或执行您想要的任何操作。

票数 0
EN
页面原文内容由Data Science提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://datascience.stackexchange.com/questions/87545

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档