Skip to content

Commit

Permalink
Rework on PrometheusDataset (#402)
Browse files Browse the repository at this point in the history
This is part of the effort to reduce the dedicated C++
implementation of Dataset and replace it with primitive ops
that can be used both with tf.data and with Tensor.

There is some room for enhancement. For example,
a timestamp could be passed to read_prometheus
so that each call reads only a small slice of the data
into a tensor. Follow-up PRs will implement that.

Signed-off-by: Yong Tang <yong.tang.github@outlook.com>
  • Loading branch information
yongtang authored Aug 4, 2019
1 parent 8af0a01 commit c855890
Show file tree
Hide file tree
Showing 8 changed files with 135 additions and 137 deletions.
4 changes: 1 addition & 3 deletions tensorflow_io/prometheus/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,7 @@ load(
cc_library(
name = "prometheus_ops",
srcs = [
#"//tensorflow_io/prometheus/go:prometheus.a",
#"//tensorflow_io/prometheus/go:prometheus.h",
"kernels/prometheus_input.cc",
"kernels/prometheus_kernels.cc",
"ops/prometheus_ops.cc",
],
copts = tf_io_copts(),
Expand Down
3 changes: 3 additions & 0 deletions tensorflow_io/prometheus/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,18 +15,21 @@
"""PrometheusInput
@@PrometheusDataset
@@read_prometheus
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow_io.prometheus.python.ops.prometheus_ops import PrometheusDataset
from tensorflow_io.prometheus.python.ops.prometheus_ops import read_prometheus

from tensorflow.python.util.all_util import remove_undocumented

_allowed_symbols = [
"PrometheusDataset",
"read_prometheus",
]

remove_undocumented(__name__, allowed_exception_list=_allowed_symbols)
24 changes: 13 additions & 11 deletions tensorflow_io/prometheus/go/prometheus.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,26 +13,28 @@ import (
)

// Query evaluates `query` against the Prometheus server at `endpoint`, at the
// instant given by the Unix time `ts` (seconds), and copies the samples of the
// first series of a matrix result into the caller-provided slices.
//
// Return value:
//   -1  the API client could not be created or the query failed;
//    0  the result is not a non-empty model.Matrix;
//    n  the number of samples in the first series (m[0]).
//
// The count n is returned whether or not the samples were copied, so a caller
// can pass empty slices to learn the size and then call again with buffers of
// that size. Only m[0] is read; any further series in the matrix are ignored.
//export Query
func Query(endpoint string, query string, sec int64, offset int64, key []int64, val []float64) int64 {
func Query(endpoint string, query string, ts int64, timestamp []int64, value []float64) int {
client, err := api.NewClient(api.Config{
Address: endpoint,
})
if err != nil {
return -1
}
value, err := v1.NewAPI(client).Query(context.Background(), query, time.Unix(sec, 0))
v, err := v1.NewAPI(client).Query(context.Background(), query, time.Unix(ts, 0))
if err != nil {
return -1
}
if m, ok := value.(model.Matrix); ok && m.Len() > 0 {
index := int64(0)
for index < int64(len(key)) && offset+index < int64(len(m[0].Values)) {
v := m[0].Values[offset+index]
key[index] = v.Timestamp.Unix()
val[index] = float64(v.Value)
index++
if m, ok := v.(model.Matrix); ok && m.Len() > 0 {
// Copy only when the buffers fit: `timestamp` may be larger than the sample
// count, but `value` must match it exactly. Otherwise nothing is written.
// NOTE(review): the `>=` / `==` asymmetry looks unintentional — confirm.
if len(timestamp) >= len(m[0].Values) && len(value) == len(m[0].Values) {

for i := 0; i < len(m[0].Values); i++ {
v := m[0].Values[i]
// The raw model.Time integer is stored, with no .Unix() conversion.
// NOTE(review): presumably milliseconds since the epoch — confirm that
// consumers of the timestamps expect this unit.
timestamp[i] = int64(v.Timestamp)
value[i] = float64(v.Value)
}
}
return index

// Report the sample count even if the copy above was skipped.
return len(m[0].Values)
}
return 0
}
Expand All @@ -42,7 +44,7 @@ func main() {
val := make([]float64, 20, 20)
sec := time.Now().Unix()
fmt.Println(sec)
returned := Query("http://localhost:9090", "coredns_dns_request_count_total[5m]", sec, 0, key, val)
returned := Query("http://localhost:9090", "coredns_dns_request_count_total[5m]", sec, key, val)
fmt.Println(returned)
for i := range key {
fmt.Printf("%d, %q, %v\n", i, model.TimeFromUnix(key[i]).Time(), val[i])
Expand Down
83 changes: 0 additions & 83 deletions tensorflow_io/prometheus/kernels/prometheus_input.cc

This file was deleted.

77 changes: 77 additions & 0 deletions tensorflow_io/prometheus/kernels/prometheus_kernels.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/framework/op_kernel.h"
#include "go/prometheus.h"

namespace tensorflow {
namespace data {
namespace {

class ReadPrometheusOp : public OpKernel {
public:
explicit ReadPrometheusOp(OpKernelConstruction* context) : OpKernel(context) {
env_ = context->env();
}

void Compute(OpKernelContext* context) override {
const Tensor& endpoint_tensor = context->input(0);
const string& endpoint = endpoint_tensor.scalar<string>()();

const Tensor& query_tensor = context->input(1);
const string& query = query_tensor.scalar<string>()();

int64 ts = time(NULL);

GoString endpoint_go = {endpoint.c_str(), static_cast<int64>(endpoint.size())};
GoString query_go = {query.c_str(), static_cast<int64>(query.size())};

GoSlice timestamp_go = {0, 0, 0};
GoSlice value_go = {0, 0, 0};

GoInt returned = Query(endpoint_go, query_go, ts, timestamp_go, value_go);
OP_REQUIRES(context, returned >= 0, errors::InvalidArgument("unable to query prometheus"));

TensorShape output_shape({static_cast<int64>(returned)});

Tensor* timestamp_tensor;
OP_REQUIRES_OK(context, context->allocate_output(0, output_shape, &timestamp_tensor));
Tensor* value_tensor;
OP_REQUIRES_OK(context, context->allocate_output(1, output_shape, &value_tensor));

if (returned > 0) {
timestamp_go.data = timestamp_tensor->flat<int64>().data();
timestamp_go.len = returned;
timestamp_go.cap = returned;
value_go.data = value_tensor->flat<double>().data();
value_go.len = returned;
value_go.cap = returned;

returned = Query(endpoint_go, query_go, ts, timestamp_go, value_go);
OP_REQUIRES(context, returned >= 0, errors::InvalidArgument("unable to query prometheus to get the value"));
}
}
private:
mutex mu_;
Env* env_ GUARDED_BY(mu_);
};

// Bind ReadPrometheusOp to the "ReadPrometheus" op; only a CPU kernel is
// registered.
REGISTER_KERNEL_BUILDER(Name("ReadPrometheus").Device(DEVICE_CPU),
ReadPrometheusOp);


} // namespace
} // namespace data
} // namespace tensorflow
25 changes: 6 additions & 19 deletions tensorflow_io/prometheus/ops/prometheus_ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -19,27 +19,14 @@ limitations under the License.

namespace tensorflow {

REGISTER_OP("PrometheusInput")
.Input("source: string")
.Output("handle: variant")
.Attr("filters: list(string) = []")
.Attr("columns: list(string) = []")
.Attr("schema: string = ''")
REGISTER_OP("ReadPrometheus")
.Input("endpoint: string")
.Input("query: string")
.Output("timestamp: int64")
.Output("value: float64")
.SetShapeFn([](shape_inference::InferenceContext* c) {
c->set_output(0, c->MakeShape({c->UnknownDim()}));
return Status::OK();
});

REGISTER_OP("PrometheusDataset")
.Input("input: T")
.Input("batch: int64")
.Output("handle: variant")
.Attr("output_types: list(type) >= 1")
.Attr("output_shapes: list(shape) >= 1")
.Attr("T: {string, variant} = DT_VARIANT")
.SetIsStateful()
.SetShapeFn([](shape_inference::InferenceContext* c) {
c->set_output(0, c->MakeShape({}));
c->set_output(1, c->MakeShape({c->UnknownDim()}));
return Status::OK();
});

Expand Down
40 changes: 24 additions & 16 deletions tensorflow_io/prometheus/python/ops/prometheus_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,29 +18,37 @@
from __future__ import print_function

import tensorflow as tf
from tensorflow_io.core.python.ops import data_ops as data_ops
from tensorflow_io.core.python.ops import core_ops as prometheus_ops
from tensorflow_io.core.python.ops import data_ops
from tensorflow_io.core.python.ops import core_ops

class PrometheusDataset(data_ops.Dataset):
"""A Prometheus Dataset
"""
def read_prometheus(endpoint, query):
"""read_prometheus"""
return core_ops.read_prometheus(endpoint, query)

def __init__(self, endpoint, schema=None, batch=None):
"""Create a Prometheus Reader.
class PrometheusDataset(data_ops.BaseDataset):
"""A Prometheus Dataset"""

def __init__(self, endpoint, query):
"""Create a Prometheus Dataset
Args:
endpoint: A `tf.string` tensor containing address of
the prometheus server.
schema: A `tf.string` tensor containing the query
query: A `tf.string` tensor containing the query
string.
batch: Size of the batch.
"""
batch = 0 if batch is None else batch
dtypes = [tf.int64, tf.float64]
shapes = [
tf.TensorShape([]), tensorflow.TensorShape([])] if batch == 0 else [
tf.TensorShape([None]), tf.TensorShape([None])]
shapes = [tf.TensorShape([None]), tf.TensorShape([None])]
# TODO: It could be possible to improve the performance
# by reading a small chunk of the data while at the same
# time allowing reuse of read_prometheus. Essentially
# read_prometheus could take a timestamp and read small chunk
# at a time until running out of data.
timestamp, value = read_prometheus(endpoint, query)
timestamp_dataset = data_ops.BaseDataset.from_tensors(timestamp)
value_dataset = data_ops.BaseDataset.from_tensors(value)
dataset = data_ops.BaseDataset.zip((timestamp_dataset, value_dataset))

self._dataset = dataset
super(PrometheusDataset, self).__init__(
prometheus_ops.prometheus_dataset,
prometheus_ops.prometheus_input(endpoint, schema=schema),
batch, dtypes, shapes)
self._dataset._variant_tensor, dtypes, shapes) # pylint: disable=protected-access
16 changes: 11 additions & 5 deletions tests/test_prometheus_eager.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,17 +32,17 @@
pytest.skip(
"prometheus is not supported on macOS yet", allow_module_level=True)

def test_prometheus_input():
"""test_prometheus_input
"""
def test_prometheus():
"""test_prometheus"""
for _ in range(6):
subprocess.call(["dig", "@localhost", "-p", "1053", "www.google.com"])
time.sleep(1)
time.sleep(2)
prometheus_dataset = prometheus_io.PrometheusDataset(
"http://localhost:9090",
schema="coredns_dns_request_count_total[5s]",
batch=2)
"coredns_dns_request_count_total[5s]").apply(
tf.data.experimental.unbatch()).batch(2)

i = 0
for k, v in prometheus_dataset:
print("K, V: ", k.numpy(), v.numpy())
Expand All @@ -52,5 +52,11 @@ def test_prometheus_input():
i += 2
assert i == 6

timestamp, value = prometheus_io.read_prometheus(
"http://localhost:9090",
"coredns_dns_request_count_total[5s]")
assert timestamp.shape == [5]
assert value.shape == [5]

if __name__ == "__main__":
test.main()

0 comments on commit c855890

Please sign in to comment.