diff --git a/pipelines/examples/text_to_image/README.md b/pipelines/examples/text_to_image/README.md
new file mode 100644
index 000000000000..23512c40ef23
--- /dev/null
+++ b/pipelines/examples/text_to_image/README.md
@@ -0,0 +1,114 @@
+# ERNIE-ViLG Text-to-Image System
+
+## 1. Overview
+
+ERNIE-ViLG is a knowledge-enhanced, cross-modal generation model that unifies text-to-image and image-to-text generation in a single model trained end to end, achieving cross-modal semantic alignment between text and images. It supports content creation and gives every user access to a low-barrier creative platform. For more details, see the official introduction: [ernieVilg](https://wenxin.baidu.com/moduleApi/ernieVilg).
+
+
+## 2. Features
+
+This project provides a low-cost way to build an end-to-end text-to-image system. After a few simple configuration steps, users can enter prompts and generate artwork in a variety of styles. In addition, Pipelines provides a web-based service so that the text-to-image system can be set up locally.
+
+
+## 3. Quick Start: Building the Text-to-Image System
+
+
+### 3.1 Environment and Installation
+
+This example was developed in the following environment; you can also run it in your own environment:
+
+a. Software environment:
+- python >= 3.7.0
+- paddlenlp >= 2.4.0
+- paddlepaddle-gpu >=2.3
+- CUDA Version: 10.2
+- NVIDIA Driver Version: 440.64.00
+- Ubuntu 16.04.6 LTS (Docker)
+
+b. Hardware environment:
+
+- 4 x NVIDIA Tesla V100 16GB
+- Intel(R) Xeon(R) Gold 6148 CPU @ 2.40GHz
+
+c. Installing dependencies:
+First install PaddlePaddle by following the [official installation guide](https://www.paddlepaddle.org.cn/install/quick?docurl=/documentation/docs/zh/install/pip/linux-pip.html), then install the remaining dependencies:
+```bash
+# One-command install via pip
+pip install --upgrade paddle-pipelines -i https://pypi.tuna.tsinghua.edu.cn/simple
+# Or install the latest version from source
+cd ${HOME}/PaddleNLP/pipelines/
+pip install -r requirements.txt -i https://pypi.tuna.tsinghua.edu.cn/simple
+python setup.py install
+```
+[Note] All of the following steps are run from the `pipelines` root directory; there is no need to change into subdirectories. Also, the text-to-image system calls an online API, so an internet connection is required.
+
+
+### 3.2 Trying Out the Text-to-Image System
+
+Before running the command below, apply for an `API Key` and a `Secret key` on the [ERNIE-ViLG website](https://wenxin.baidu.com/moduleApi/ernieVilg) (log in first, then click 查看AK/SK in the upper-right corner to view them), and then run the command.
+
+#### 3.2.1 One-Command Quick Start
+
+You can quickly try the text-to-image system with the following command:
+```bash
+python examples/text_to_image/text_to_image_example.py --prompt_text 宁静的小镇 \
+                                                       --style 古风 \
+                                                       --topk 5 \
+                                                       --api_key your_api_key \
+                                                       --secret_key your_secret_key \
+                                                       --output_dir ernievilg_output
+```
+The run takes roughly one minute; the generated images are saved under the directory given by `--output_dir`.
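+
+If you prefer to call the pipeline from Python rather than through the example script, the minimal sketch below builds the same pipeline programmatically (it mirrors `text_to_image_example.py`; `your_api_key` and `your_secret_key` are placeholders for the AK/SK you applied for):
+
+```python
+from pipelines import TextToImagePipeline
+from pipelines.nodes import ErnieTextToImageGenerator
+
+# Replace the placeholders with the AK/SK obtained from the ERNIE-ViLG website
+generator = ErnieTextToImageGenerator(ak="your_api_key", sk="your_secret_key")
+pipe = TextToImagePipeline(generator)
+
+# Images are written to output_dir; the run output also contains their local paths (key "results")
+result = pipe.run(query="宁静的小镇",
+                  params={
+                      "TextToImageGenerator": {
+                          "topk": 5,
+                          "style": "古风",
+                          "resolution": "1024*1024",
+                          "output_dir": "ernievilg_output"
+                      }
+                  })
+```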
+
+### 3.3 Building the Web-Based Text-to-Image System
+
+The web-based text-to-image system consists of two components: 1. a model service built on a RESTful API and 2. a WebUI built with Gradio. The next steps set up these two services to form the complete visual text-to-image system.
+
+#### 3.3.1 Starting the REST API Model Service
+
+Before starting the service, add the `API Key` and `Secret key` you applied for to the `ak` and `sk` fields of `text_to_image.yaml`, then run:
+
+```bash
+export PIPELINE_YAML_PATH=rest_api/pipeline/text_to_image.yaml
+# Start the model service on port 8891
+python rest_api/application.py 8891
+```
+Linux users can instead start the service with the provided shell script:
+
+```bash
+sh examples/text_to_image/run_text_to_image.sh
+```
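+
+Once the service is up, you can also query the `/query_text_to_images` endpoint directly. The snippet below is a minimal sketch (assuming the service is listening on 127.0.0.1:8891 as started above); the local paths of the generated images are returned in the `answers` field of the response:
+
+```python
+import requests
+
+# Prompt text plus optional parameters for the TextToImageGenerator node
+req = {
+    "query": "宁静的小镇",
+    "params": {
+        "TextToImageGenerator": {
+            "style": "古风",
+            "topk": 5,
+            "resolution": "1024*1024"
+        }
+    }
+}
+response = requests.post("http://127.0.0.1:8891/query_text_to_images", json=req)
+print(response.json()["answers"])  # local paths of the generated images
+```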
+
+#### 3.3.2 Starting the WebUI
+
+The WebUI is built on the [Gradio](https://gradio.app/) front end. Install gradio first:
+```bash
+pip install gradio
+```
+Then start the WebUI with:
+```bash
+# Point the WebUI at the model service
+export API_ENDPOINT=http://127.0.0.1:8891
+# Start the WebUI on port 8502
+python ui/webapp_text_to_image.py --serving_port 8502
+```
+Linux users can instead start the service with the provided shell script:
+
+```bash
+sh examples/text_to_image/run_text_to_image_web.sh
+```
+
+You can now open http://127.0.0.1:8502 in your browser to try the text-to-image service.
+
+If you run into problems during installation, see the [FAQ](../../FAQ.md).
+
+## Acknowledgements
+
+We learn from the excellent framework design of Deepset.ai [Haystack](https://github.com/deepset-ai/haystack), and we would like to express our thanks to the authors of Haystack and their open-source community.
diff --git a/pipelines/examples/text_to_image/run_text_to_image.sh b/pipelines/examples/text_to_image/run_text_to_image.sh
new file mode 100644
index 000000000000..4a61f0b98e9e
--- /dev/null
+++ b/pipelines/examples/text_to_image/run_text_to_image.sh
@@ -0,0 +1,19 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Specify the YAML config file for the text-to-image pipeline
+unset http_proxy && unset https_proxy
+export PIPELINE_YAML_PATH=rest_api/pipeline/text_to_image.yaml
+# Start the model service on port 8891
+python rest_api/application.py 8891
\ No newline at end of file
diff --git a/pipelines/examples/text_to_image/run_text_to_image_web.sh b/pipelines/examples/text_to_image/run_text_to_image_web.sh
new file mode 100644
index 000000000000..05a59f7be69f
--- /dev/null
+++ b/pipelines/examples/text_to_image/run_text_to_image_web.sh
@@ -0,0 +1,18 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Point the WebUI at the model service
+export API_ENDPOINT=http://127.0.0.1:8891
+# Start the WebUI on port 8502
+python ui/webapp_text_to_image.py --serving_port 8502
\ No newline at end of file
diff --git a/pipelines/examples/text_to_image/text_to_image_example.py b/pipelines/examples/text_to_image/text_to_image_example.py
new file mode 100644
index 000000000000..8637b1a52aa4
--- /dev/null
+++ b/pipelines/examples/text_to_image/text_to_image_example.py
@@ -0,0 +1,53 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import argparse
+
+from pipelines.nodes import ErnieTextToImageGenerator
+from pipelines import TextToImagePipeline
+
+# yapf: disable
+parser = argparse.ArgumentParser()
+parser.add_argument("--api_key", default=None, type=str, help="The API Key.")
+parser.add_argument("--secret_key", default=None, type=str, help="The secret key.")
+parser.add_argument("--prompt_text", default='宁静的小镇', type=str, help="The prompt_text.")
+parser.add_argument("--output_dir", default='ernievilg_output', type=str, help="The output path.")
+parser.add_argument("--style", default='探索无限', type=str, help="The style text.")
+parser.add_argument("--size", default='1024*1024',
+                    choices=['1024*1024', '1024*1536', '1536*1024'], help="Size of the generated images.")
+parser.add_argument("--topk", default=5, type=int, help="The top k images.")
+args = parser.parse_args()
+# yapf: enable
+
+
+def text_to_image():
+    ernie_image_generator = ErnieTextToImageGenerator(ak=args.api_key,
+                                                      sk=args.secret_key)
+    pipe = TextToImagePipeline(ernie_image_generator)
+ prediction = pipe.run(query=args.prompt_text,
+ params={
+ "TextToImageGenerator": {
+ "topk": args.topk,
+ "style": args.style,
+ "resolution": args.size,
+ "output_dir": args.output_dir
+ }
+ })
+ pipe.save_to_yaml('text_to_image.yaml')
+
+
+if __name__ == "__main__":
+ text_to_image()
diff --git a/pipelines/pipelines/__init__.py b/pipelines/pipelines/__init__.py
index 83dc75fcf6b2..ede1272c1316 100644
--- a/pipelines/pipelines/__init__.py
+++ b/pipelines/pipelines/__init__.py
@@ -39,7 +39,8 @@
from pipelines.pipelines import Pipeline
from pipelines.pipelines.standard_pipelines import (BaseStandardPipeline,
ExtractiveQAPipeline,
- SemanticSearchPipeline)
+ SemanticSearchPipeline,
+ TextToImagePipeline)
import pandas as pd
diff --git a/pipelines/pipelines/nodes/__init__.py b/pipelines/pipelines/nodes/__init__.py
index a4285acaaf47..a56fd1ccbcfd 100644
--- a/pipelines/pipelines/nodes/__init__.py
+++ b/pipelines/pipelines/nodes/__init__.py
@@ -29,3 +29,4 @@
from pipelines.nodes.ranker import BaseRanker, ErnieRanker
from pipelines.nodes.reader import BaseReader, ErnieReader
from pipelines.nodes.retriever import BaseRetriever, DensePassageRetriever
+from pipelines.nodes.text_to_image_generator import ErnieTextToImageGenerator
diff --git a/pipelines/pipelines/nodes/text_to_image_generator/__init__.py b/pipelines/pipelines/nodes/text_to_image_generator/__init__.py
new file mode 100644
index 000000000000..579c485c01a0
--- /dev/null
+++ b/pipelines/pipelines/nodes/text_to_image_generator/__init__.py
@@ -0,0 +1,15 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from pipelines.nodes.text_to_image_generator.text_to_image_generator import ErnieTextToImageGenerator
diff --git a/pipelines/pipelines/nodes/text_to_image_generator/text_to_image_generator.py b/pipelines/pipelines/nodes/text_to_image_generator/text_to_image_generator.py
new file mode 100644
index 000000000000..9a7bf2aa7389
--- /dev/null
+++ b/pipelines/pipelines/nodes/text_to_image_generator/text_to_image_generator.py
@@ -0,0 +1,266 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import time
+import requests
+import hashlib
+from io import BytesIO
+from PIL import Image
+from typing import List
+from typing import Optional
+from tqdm.auto import tqdm
+
+from pipelines.schema import Document
+from pipelines.nodes.base import BaseComponent
+
+
+class ErnieTextToImageGenerator(BaseComponent):
+ """
+    ErnieTextToImageGenerator uses the ERNIE-ViLG API for text-to-image generation.
+ """
+
+ def __init__(self, ak=None, sk=None):
+ """
+        :param ak: The API key (AK) used to request an access token from the Wenxin API.
+        :param sk: The secret key (SK) used to request an access token from the Wenxin API.
+ """
+        if ak is None or sk is None:
+ raise Exception(
+ "Please apply api_key and secret_key from https://wenxin.baidu.com/moduleApi/ernieVilg"
+ )
+ self.ak = ak
+ self.sk = sk
+ self.token_host = 'https://wenxin.baidu.com/younger/portal/api/oauth/token'
+ self.token = self._apply_token(self.ak, self.sk)
+
+ # save init parameters to enable export of component config as YAML
+ self.set_config(
+ ak=ak,
+ sk=sk,
+ )
+
+ def _apply_token(self, ak, sk):
+ if ak is None or sk is None:
+ ak = self.ak
+ sk = self.sk
+ response = requests.get(self.token_host,
+ params={
+ 'grant_type': 'client_credentials',
+ 'client_id': ak,
+ 'client_secret': sk
+ })
+ if response:
+ res = response.json()
+ if res['code'] != 0:
+ print('Request access token error.')
+ raise RuntimeError("Request access token error.")
+ else:
+ print('Request access token error.')
+ raise RuntimeError("Request access token error.")
+ return res['data']
+
+ def generate_image(self,
+ text_prompts,
+ style: Optional[str] = "探索无限",
+ resolution: Optional[str] = "1024*1024",
+ topk: Optional[int] = 6,
+ visualization: Optional[bool] = True,
+ output_dir: Optional[str] = 'ernievilg_output'):
+ """
+        Create images from text prompts using the ERNIE-ViLG model.
+        :param text_prompts: Phrase, sentence, or string of words and phrases describing what the image should look like.
+        :param style: Image style. Currently supported styles: 古风、油画、水彩、卡通、二次元、浮世绘、蒸汽波艺术、
+            low poly、像素风格、概念艺术、未来主义、赛博朋克、写实风格、洛丽塔风格、巴洛克风格、超现实主义、探索无限。
+        :param resolution: Resolution of the generated images; currently supported: "1024*1024", "1024*1536", "1536*1024".
+        :param topk: Number of top images to save.
+        :param visualization: Whether to save the generated images to disk.
+        :param output_dir: Output directory.
+ """
+ if not os.path.exists(output_dir):
+ os.makedirs(output_dir, exist_ok=True)
+ token = self.token
+ create_url = 'https://wenxin.baidu.com/younger/portal/api/rest/1.0/ernievilg/v1/txt2img?from=paddlehub'
+ get_url = 'https://wenxin.baidu.com/younger/portal/api/rest/1.0/ernievilg/v1/getImg?from=paddlehub'
+ if isinstance(text_prompts, str):
+ text_prompts = [text_prompts]
+ taskids = []
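+        # Submit one txt2img request per prompt and collect the returned task IDs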
+ for text_prompt in text_prompts:
+ res = requests.post(
+ create_url,
+ headers={'Content-Type': 'application/x-www-form-urlencoded'},
+ data={
+ 'access_token': token,
+ "text": text_prompt,
+ "style": style,
+ "resolution": resolution
+ })
+ res = res.json()
+ if res['code'] == 4001:
+ print('请求参数错误')
+ raise RuntimeError("请求参数错误")
+ elif res['code'] == 4002:
+ print('请求参数格式错误,请检查必传参数是否齐全,参数类型等')
+ raise RuntimeError("请求参数格式错误,请检查必传参数是否齐全,参数类型等")
+ elif res['code'] == 4003:
+ print('请求参数中,图片风格不在可选范围内')
+ raise RuntimeError("请求参数中,图片风格不在可选范围内")
+ elif res['code'] == 4004:
+ print('API服务内部错误,可能引起原因有请求超时、模型推理错误等')
+ raise RuntimeError("API服务内部错误,可能引起原因有请求超时、模型推理错误等")
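+            # Codes 100/110/111 indicate the access token has expired or is invalid; refresh it and retry the request once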
+ elif res['code'] == 100 or res['code'] == 110 or res['code'] == 111:
+ token = self._apply_token(self.ak, self.sk)
+ res = requests.post(create_url,
+ headers={
+ 'Content-Type':
+ 'application/x-www-form-urlencoded'
+ },
+ data={
+ 'access_token': token,
+ "text": text_prompt,
+ "style": style,
+ "resolution": resolution
+ })
+ res = res.json()
+ if res['code'] != 0:
+ print("Token失效重新请求后依然发生错误,请检查输入的参数")
+ raise RuntimeError("Token失效重新请求后依然发生错误,请检查输入的参数")
+ if res['msg'] == 'success':
+ taskids.append(res['data']["taskId"])
+ else:
+ print(res['msg'])
+ raise RuntimeError(res['msg'])
+
+ start_time = time.time()
+ process_bar = tqdm(total=100, unit='%')
+ results = {}
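+        # Rough estimate to drive the progress bar: assume each submitted task takes about 60 seconds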
+ total_time = 60 * len(taskids)
+ while True:
+ end_time = time.time()
+ duration = end_time - start_time
+ progress_rate = int((duration) / total_time * 100)
+ if not taskids:
+ progress_rate = 100
+ if progress_rate > process_bar.n:
+ if progress_rate >= 100:
+ if not taskids:
+ increase_rate = 100 - process_bar.n
+ else:
+ increase_rate = 0
+ else:
+ increase_rate = progress_rate - process_bar.n
+ else:
+ increase_rate = 0
+ process_bar.update(increase_rate)
+ if duration < 30:
+ time.sleep(5)
+ continue
+ else:
+ time.sleep(6)
+ if not taskids:
+ break
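+            # Poll the status of every pending task; finished tasks are collected in has_done and removed from the queue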
+ has_done = []
+ for taskid in taskids:
+ res = requests.post(get_url,
+ headers={
+ 'Content-Type':
+ 'application/x-www-form-urlencoded'
+ },
+ data={
+ 'access_token': token,
+                                        'taskId': taskid
+ })
+ res = res.json()
+ if res['code'] == 4001:
+ print('请求参数错误')
+ raise RuntimeError("请求参数错误")
+ elif res['code'] == 4002:
+ print('请求参数格式错误,请检查必传参数是否齐全,参数类型等')
+ raise RuntimeError("请求参数格式错误,请检查必传参数是否齐全,参数类型等")
+ elif res['code'] == 4003:
+ print('请求参数中,图片风格不在可选范围内')
+ raise RuntimeError("请求参数中,图片风格不在可选范围内")
+ elif res['code'] == 4004:
+ print('API服务内部错误,可能引起原因有请求超时、模型推理错误等')
+ raise RuntimeError("API服务内部错误,可能引起原因有请求超时、模型推理错误等")
+ elif res['code'] == 100 or res['code'] == 110 or res[
+ 'code'] == 111:
+ token = self._apply_token(self.ak, self.sk)
+ res = requests.post(get_url,
+ headers={
+ 'Content-Type':
+ 'application/x-www-form-urlencoded'
+ },
+ data={
+ 'access_token': token,
+                                            'taskId': taskid
+ })
+ res = res.json()
+ if res['code'] != 0:
+ print("Token失效重新请求后依然发生错误,请检查输入的参数")
+ raise RuntimeError("Token失效重新请求后依然发生错误,请检查输入的参数")
+ if res['msg'] == 'success':
+ if res['data']['status'] == 1:
+ has_done.append(res['data']['taskId'])
+ results[res['data']['text']] = {
+ 'imgUrls': res['data']['imgUrls'],
+ 'waiting': res['data']['waiting'],
+ 'taskId': res['data']['taskId']
+ }
+ else:
+ print(res['msg'])
+ raise RuntimeError(res['msg'])
+ for taskid in has_done:
+ taskids.remove(taskid)
+ print('Saving Images...')
+ result_images = []
+ for text, data in results.items():
+ for idx, imgdata in enumerate(data['imgUrls']):
+ try:
+ image = Image.open(
+ BytesIO(requests.get(imgdata['image']).content))
+                except Exception:
+                    print('Downloading the generated image failed, retrying once')
+ try:
+ image = Image.open(
+ BytesIO(requests.get(imgdata['image']).content))
+ except Exception:
+ raise RuntimeError('Download generated images failed.')
+ if visualization:
+ ext = 'png'
+ md5hash = hashlib.md5(image.tobytes())
+ md5_name = md5hash.hexdigest()
+ image_name = '{}.{}'.format(md5_name, ext)
+ image_path = os.path.join(output_dir, image_name)
+ image.save(image_path)
+ result_images.append(image_path)
+ if idx + 1 >= topk:
+ break
+ print('Done')
+ return result_images
+
+ def run(self,
+ query: Document,
+ style: Optional[str] = None,
+ topk: Optional[int] = None,
+ resolution: Optional[str] = "1024*1024",
+ output_dir: Optional[str] = 'ernievilg_output'):
+
+ result_images = self.generate_image(query,
+ style=style,
+ topk=topk,
+ resolution=resolution,
+ output_dir=output_dir)
+ results = {"results": result_images}
+ return results, "output_1"
diff --git a/pipelines/pipelines/pipelines/__init__.py b/pipelines/pipelines/pipelines/__init__.py
index a1a25f53aa4b..04f1367033a6 100644
--- a/pipelines/pipelines/pipelines/__init__.py
+++ b/pipelines/pipelines/pipelines/__init__.py
@@ -15,4 +15,5 @@
from pipelines.pipelines.base import Pipeline, RootNode
from pipelines.pipelines.standard_pipelines import (BaseStandardPipeline,
ExtractiveQAPipeline,
- SemanticSearchPipeline)
+ SemanticSearchPipeline,
+ TextToImagePipeline)
diff --git a/pipelines/pipelines/pipelines/standard_pipelines.py b/pipelines/pipelines/pipelines/standard_pipelines.py
index d459c33db7c2..2b2fc5cbe769 100644
--- a/pipelines/pipelines/pipelines/standard_pipelines.py
+++ b/pipelines/pipelines/pipelines/standard_pipelines.py
@@ -24,6 +24,7 @@
from pipelines.nodes.ranker import BaseRanker
from pipelines.nodes.retriever import BaseRetriever
from pipelines.document_stores import BaseDocumentStore
+from pipelines.nodes.text_to_image_generator import ErnieTextToImageGenerator
from pipelines.pipelines import Pipeline
logger = logging.getLogger(__name__)
@@ -263,3 +264,34 @@ def run(self,
"""
output = self.pipeline.run(query=query, params=params, debug=debug)
return output
+
+
+class TextToImagePipeline(BaseStandardPipeline):
+ """
+ A simple pipeline that takes prompt texts as input and generates
+ images.
+ """
+
+ def __init__(self, text_to_image_generator: ErnieTextToImageGenerator):
+ self.pipeline = Pipeline()
+ self.pipeline.add_node(component=text_to_image_generator,
+ name="TextToImageGenerator",
+ inputs=["Query"])
+
+ def run(self,
+ query: str,
+ params: Optional[dict] = None,
+ debug: Optional[bool] = None):
+ output = self.pipeline.run(query=query, params=params, debug=debug)
+ return output
+
+ def run_batch(
+ self,
+ documents: List[Document],
+ params: Optional[dict] = None,
+ debug: Optional[bool] = None,
+ ):
+ output = self.pipeline.run_batch(documents=documents,
+ params=params,
+ debug=debug)
+ return output
diff --git a/pipelines/rest_api/controller/search.py b/pipelines/rest_api/controller/search.py
index 29c7359e608b..780137440225 100644
--- a/pipelines/rest_api/controller/search.py
+++ b/pipelines/rest_api/controller/search.py
@@ -27,7 +27,7 @@
from pipelines.pipelines.base import Pipeline
from rest_api.config import PIPELINE_YAML_PATH, QUERY_PIPELINE_NAME
from rest_api.config import LOG_LEVEL, CONCURRENT_REQUEST_PER_WORKER
-from rest_api.schema import QueryRequest, QueryResponse
+from rest_api.schema import QueryRequest, QueryResponse, QueryImageResponse
from rest_api.controller.utils import RequestLimiter
logging.getLogger("pipelines").setLevel(LOG_LEVEL)
@@ -81,6 +81,27 @@ def query(request: QueryRequest):
return result
+@router.post("/query_text_to_images",
+ response_model=QueryImageResponse,
+ response_model_exclude_none=True)
+def query_images(request: QueryRequest):
+ """
+    This endpoint receives the prompt text as a string and allows the requester to set
+    additional parameters that will be passed on to the pipeline.
+ """
+ result = {}
+ result['query'] = request.query
+ params = request.params or {}
+ res = PIPELINE.run(query=request.query, params=params, debug=request.debug)
+ # Ensure answers and documents exist, even if they're empty lists
+ result['answers'] = res['results']
+    if "documents" not in result:
+        result["documents"] = []
+    if "answers" not in result:
+        result["answers"] = []
+ return result
+
+
def _process_request(pipeline, request) -> Dict[str, Any]:
start_time = time.time()
diff --git a/pipelines/rest_api/pipeline/text_to_image.yaml b/pipelines/rest_api/pipeline/text_to_image.yaml
new file mode 100644
index 000000000000..959781b2f225
--- /dev/null
+++ b/pipelines/rest_api/pipeline/text_to_image.yaml
@@ -0,0 +1,16 @@
+version: '1.1.0'
+
+components:
+ - name: TextToImageGenerator
+ params:
+ ak:
+ sk:
+ type: ErnieTextToImageGenerator
+pipelines:
+ - name: query
+ type: Query
+ nodes:
+ - name: TextToImageGenerator
+ inputs: [Query]
+
+
diff --git a/pipelines/rest_api/schema.py b/pipelines/rest_api/schema.py
index 942a4e7029ef..e041d2bad62e 100644
--- a/pipelines/rest_api/schema.py
+++ b/pipelines/rest_api/schema.py
@@ -83,3 +83,10 @@ class QueryResponse(BaseModel):
answers: List[AnswerSerialized] = []
documents: List[DocumentSerialized] = []
debug: Optional[Dict] = Field(None, alias="_debug")
+
+
+class QueryImageResponse(BaseModel):
+ query: str
+ answers: List[str] = []
+ documents: List[DocumentSerialized] = []
+ debug: Optional[Dict] = Field(None, alias="_debug")
diff --git a/pipelines/ui/utils.py b/pipelines/ui/utils.py
index 1d672613dbb1..540c44cc2247 100644
--- a/pipelines/ui/utils.py
+++ b/pipelines/ui/utils.py
@@ -31,6 +31,7 @@
DOC_FEEDBACK = "feedback"
DOC_UPLOAD = "file-upload"
DOC_PARSE = 'files'
+IMAGE_REQUEST = 'query_text_to_images'
def pipelines_is_ready():
@@ -184,6 +185,35 @@ def semantic_search(
return results, response
+def text_to_image_search(
+ query,
+ resolution="1024*1024",
+ top_k_images=5,
+ style="探索无限") -> Tuple[List[Dict[str, Any]], Dict[str, str]]:
+ """
+ Send a prompt text and corresponding parameters to the REST API
+ """
+ url = f"{API_ENDPOINT}/{IMAGE_REQUEST}"
+ params = {
+ "TextToImageGenerator": {
+ "style": style,
+ "topk": top_k_images,
+ "resolution": resolution,
+ }
+ }
+ req = {"query": query, "params": params}
+ response_raw = requests.post(url, json=req)
+
+ if response_raw.status_code >= 400 and response_raw.status_code != 503:
+ raise Exception(f"{vars(response_raw)}")
+
+ response = response_raw.json()
+ if "errors" in response:
+ raise Exception(", ".join(response["errors"]))
+ results = response["answers"]
+ return results, response
+
+
def send_feedback(query, answer_obj, is_correct_answer, is_correct_document,
document) -> None:
"""
diff --git a/pipelines/ui/webapp_text_to_image.py b/pipelines/ui/webapp_text_to_image.py
new file mode 100644
index 000000000000..60bd521b5e34
--- /dev/null
+++ b/pipelines/ui/webapp_text_to_image.py
@@ -0,0 +1,104 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import argparse
+
+from utils import text_to_image_search
+import gradio as gr
+
+# yapf: disable
+parser = argparse.ArgumentParser()
+parser.add_argument("--serving_port", default=8502, type=int, help="Port for the serving.")
+args = parser.parse_args()
+# yapf: enable
+
+
+def infer(text_prompt, top_k_images, Size, style):
+ results, raw_json = text_to_image_search(text_prompt,
+ resolution=Size,
+ top_k_images=top_k_images,
+ style=style)
+ return results
+
+
+def main():
+ block = gr.Blocks()
+
+ with block:
+ with gr.Group():
+ with gr.Box():
+ with gr.Row().style(mobile_collapse=False, equal_height=True):
+ text_prompt = gr.Textbox(
+ label="Enter your prompt",
+ value='宁静的小镇',
+ show_label=False,
+ max_lines=1,
+ placeholder="Enter your prompt",
+ ).style(
+ border=(True, False, True, True),
+ rounded=(True, False, False, True),
+ container=False,
+ )
+ btn = gr.Button("开始生成").style(
+ margin=False,
+ rounded=(False, True, True, False),
+ )
+ gallery = gr.Gallery(label="Generated images",
+ show_label=False,
+ elem_id="gallery").style(grid=[2],
+ height="auto")
+
+ advanced_button = gr.Button("Advanced options",
+ elem_id="advanced-btn")
+
+ with gr.Row(elem_id="advanced-options"):
+ top_k_images = gr.Slider(label="Images",
+ minimum=1,
+ maximum=50,
+ value=5,
+ step=1)
+ style = gr.Radio(label='Style',
+ value='古风',
+ choices=[
+ '古风', '油画', '卡通画', '二次元', "水彩画", "浮世绘",
+ "蒸汽波艺术", "low poly", "像素风格", "概念艺术",
+ "未来主义", "赛博朋克", "写实风格", "洛丽塔风格", "巴洛克风格",
+ "超现实主义", "探索无限"
+ ])
+ Size = gr.Radio(label='Size',
+ value='1024*1024',
+ choices=['1024*1024', '1024*1536', '1536*1024'])
+
+ text_prompt.submit(infer,
+ inputs=[text_prompt, top_k_images, Size, style],
+ outputs=gallery)
+ btn.click(infer,
+ inputs=[text_prompt, top_k_images, Size, style],
+ outputs=gallery)
+ advanced_button.click(
+ None,
+ [],
+ text_prompt,
+ )
+ return block
+
+
+if __name__ == "__main__":
+ block = main()
+ block.launch(server_name='0.0.0.0',
+ server_port=args.serving_port,
+ share=False)