13 changes: 13 additions & 0 deletions tensorflow/core/framework/embedding/embedding_filter.h
@@ -50,6 +50,19 @@ class EmbeddingFilter {
 public:
  virtual void LookupOrCreate(K key, V* val, const V* default_value_ptr,
      ValuePtr<V>** value_ptr, int count, const V* default_value_no_permission) = 0;

  // Read-only lookup used on the inference path: copies the stored embedding
  // when the key exists, otherwise copies default_value_no_permission.
  // Unlike LookupOrCreate, it never inserts a new entry.
  virtual void Lookup(EV* ev, K key, V* val, const V* default_value_ptr,
                      const V* default_value_no_permission) {
    ValuePtr<V>* value_ptr = nullptr;
    Status s = ev->LookupKey(key, &value_ptr);
    if (s.ok()) {
      V* mem_val = ev->LookupPrimaryEmb(value_ptr);
      memcpy(val, mem_val, sizeof(V) * ev->ValueLen());
    } else {
      memcpy(val, default_value_no_permission, sizeof(V) * ev->ValueLen());
    }
  }

  virtual Status LookupOrCreateKey(K key, ValuePtr<V>** val, bool* is_filter) = 0;
  virtual void CreateGPUBatch(V* val_base, V** default_values, int64 size,
      int64 slice_elems, int64 value_len_, bool* init_flags, V** memcpy_address) = 0;
9 changes: 9 additions & 0 deletions tensorflow/core/framework/embedding/embedding_var.h
@@ -121,6 +121,10 @@ class EmbeddingVar : public ResourceBase {
    return is_initialized_;
  }

  // Looks the key up across the storage levels without creating it.
  Status LookupKey(K key, ValuePtr<V>** value_ptr) {
    return storage_manager_->Get(key, value_ptr);
  }

  Status LookupOrCreateKey(K key, ValuePtr<V>** value_ptr, bool* is_filter) {
    return filter_->LookupOrCreateKey(key, value_ptr, is_filter);
  }
@@ -159,6 +163,11 @@ class EmbeddingVar : public ResourceBase {
    return filter_->GetFreq(key);
  }

  // Inference-time lookup: delegates to the filter's read-only Lookup and
  // never creates new keys.
  void Lookup(K key, V* val, V* default_v) {
    const V* default_value_ptr = (default_v == nullptr) ? default_value_ : default_v;
    filter_->Lookup(this, key, val, default_value_ptr, default_value_no_permission_);
  }

  void LookupOrCreate(K key, V* val, V* default_v, int count = 1) {
    const V* default_value_ptr = (default_v == nullptr) ? default_value_ : default_v;
    ValuePtr<V>* value_ptr = nullptr;
12 changes: 12 additions & 0 deletions tensorflow/core/framework/embedding/multilevel_embedding.h
@@ -243,6 +243,18 @@ class StorageManager {
    }
  }

  // Searches each storage level in order; returns OK with the first match,
  // otherwise the status of the last level's lookup (not found).
  Status Get(K key, ValuePtr<V>** value_ptr) {
    Status s;
    int level = 0;
    for (; level < hash_table_count_; ++level) {
      s = kvs_[level].first->Lookup(key, value_ptr);
      if (s.ok()) {
        break;
      }
    }
    return s;
  }

  Status GetOrCreate(K key, ValuePtr<V>** value_ptr, size_t size) {
    bool found = false;
    int level = 0;
21 changes: 20 additions & 1 deletion tensorflow/core/kernels/kv_variable_ops.cc
@@ -50,6 +50,7 @@ namespace tensorflow {
namespace {
const int64 kEmbeddingVarUseDB = -214;
const int64 kInitializableEmbeddingVarUseDB = -215;
const char* kInferenceMode = "INFERENCE_MODE";
}

#define REGISTER_KV_VAR_HANDLE(ktype, vtype) \
@@ -438,6 +439,10 @@ template <typename TKey, typename TValue>
class KvResourceGatherOp : public OpKernel {
 public:
  explicit KvResourceGatherOp(OpKernelConstruction* c) : OpKernel(c) {
    OP_REQUIRES_OK(c, c->GetAttr("is_inference", &is_inference_));
    // The INFERENCE_MODE environment variable can also switch the kernel to
    // the read-only lookup path when the op attribute is not set.
    bool is_inference;
    TF_CHECK_OK(ReadBoolFromEnvVar(kInferenceMode, false, &is_inference));
    is_inference_ |= is_inference;
    OP_REQUIRES_OK(c,
                   c->GetAttr("is_use_default_value_tensor",
                              &is_use_default_value_tensor_));
@@ -461,6 +466,17 @@ class KvResourceGatherOp : public OpKernel {
        return 1;
      };
    }
    // Bind the lookup routine once at construction: training-mode gathers
    // create missing keys, inference-mode gathers only read.
    if (!is_inference_) {
      lookup_fn_ = [](EmbeddingVar<TKey, TValue>* ev, TKey key,
                      TValue* val, TValue* default_v, int count) {
        ev->LookupOrCreate(key, val, default_v, count);
      };
    } else {
      lookup_fn_ = [](EmbeddingVar<TKey, TValue>* ev, TKey key,
                      TValue* val, TValue* default_v, int count) {
        ev->Lookup(key, val, default_v);
      };
    }
  }

  void Compute(OpKernelContext* c) override {
@@ -511,7 +527,7 @@ class KvResourceGatherOp : public OpKernel {
              default_v, indices_flat(i), i, ev->GetDefaultValueDim(),
              ev->ValueLen());
          int32 count = get_count_fn_(counts, i);
          ev->LookupOrCreate(indices_flat(i),
          lookup_fn_(ev, indices_flat(i),
              out_base + i * slice_elems, default_v_ptr, count);
        }
      };
@@ -530,9 +546,12 @@ class KvResourceGatherOp : public OpKernel {

 private:
  bool is_use_default_value_tensor_;
  bool is_inference_;
  std::function<
      TValue*(TValue*, TKey, int64, int64, int64)> get_default_v_fn_;
  std::function<int32(int32*, int64)> get_count_fn_;
  std::function<void(EmbeddingVar<TKey, TValue>* ev,
      TKey key, TValue* val, TValue* default_v, int count)> lookup_fn_;
};

#define REGISTER_GATHER_FULL(dev, ktype, vtype) \
2 changes: 2 additions & 0 deletions tensorflow/core/ops/kv_variable_ops.cc
@@ -234,6 +234,7 @@ REGISTER_OP("KvResourceGatherV1")
.Input("counts: counts_type")
.Attr("validate_indices: bool = true")
.Attr("is_use_default_value_tensor: bool = false")
.Attr("is_inference: bool = false")
.Output("output: dtype")
.Attr("dtype: type")
.Attr("Tkeys: {int64,int32,string}")
@@ -284,6 +285,7 @@ REGISTER_OP("KvResourceGather")
.Output("output: dtype")
.Attr("dtype: type")
.Attr("Tkeys: {int64,int32,string}")
.Attr("is_inference: bool = false")
.SetShapeFn([](InferenceContext* c) {
ShapeAndType handle_shape_and_type;
TF_RETURN_IF_ERROR(
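For reference, a minimal sketch of how a serving graph can opt into the read-only lookup path that the new `is_inference` attribute and `INFERENCE_MODE` variable enable. It mirrors the unit test below; the variable name `serving_var` and the placeholder `ids` are illustrative, and it assumes DeepRec's `get_embedding_variable` wrapper used by that test.

```python
import os

from tensorflow.core.framework import attr_value_pb2
from tensorflow.python.client import session
from tensorflow.python.framework import dtypes
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import embedding_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables

# The gather kernel reads INFERENCE_MODE in its constructor, so set it
# before the first session.run instantiates the kernels.
os.environ["INFERENCE_MODE"] = "1"

var = variable_scope.get_embedding_variable(
    "serving_var", embedding_dim=3,
    initializer=init_ops.ones_initializer(dtypes.float32))
ids = array_ops.placeholder(dtype=dtypes.int64, name="ids")
emb = embedding_ops.embedding_lookup(var, ids)

# Alternative: set the attribute on the underlying KvResourceGather op
# (the same graph path as the commented-out line in the unit test):
# emb.op.inputs[0].op.inputs[0].op._set_attr(
#     "is_inference", attr_value_pb2.AttrValue(b=True))

with session.Session() as sess:
  sess.run(variables.global_variables_initializer())
  # Unseen keys are served a default value and are not inserted into the
  # embedding table, so the variable stays empty.
  print(sess.run(emb, feed_dict={"ids:0": [1, 2, 3]}))
```

When neither the attribute nor the environment variable is set, the gather falls back to LookupOrCreate and behaves as in training.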
29 changes: 29 additions & 0 deletions tensorflow/python/ops/embedding_variable_ops_test.py
@@ -16,6 +16,7 @@

from six.moves import xrange # pylint: disable=redefined-builtin

from tensorflow.core.framework import attr_value_pb2
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.ops import string_ops
@@ -2236,6 +2237,34 @@ def testEmbeddingVariableForGetFrequencyAndVersion(self):
    self.assertAllEqual(np.array([3,1,2,0,2,0,1]), f)
    self.assertAllEqual(np.array([2,0,1,0,2,0,2]), v)

  def testEmbeddingVariableForInference(self):
    print("testEmbeddingVariableForInference")
    var = variable_scope.get_embedding_variable("var_1",
            embedding_dim = 3,
            initializer=init_ops.ones_initializer(dtypes.float32),
            ev_option = variables.EmbeddingVariableOption(
                filter_option=variables.CounterFilter(filter_freq=3),
                evict_option=variables.GlobalStepEvict(steps_to_live=2))
            )
    shape = var.get_dynamic_shape()
    ids = array_ops.placeholder(dtype=dtypes.int64, name='ids')
    emb = embedding_ops.embedding_lookup(var, ids)
    # Modify the graph for inference by setting the op attribute directly:
    # emb.op.inputs[0].op.inputs[0].op._set_attr("is_inference", attr_value_pb2.AttrValue(b=True))
    # Or enable inference mode via the environment variable:
    os.environ["INFERENCE_MODE"] = "1"
    fun = math_ops.multiply(emb, 2.0, name='multiply')
    loss = math_ops.reduce_sum(fun, name='reduce_sum')
    init = variables.global_variables_initializer()
    with self.test_session() as sess:
      sess.run([init])
      sess.run([emb, loss], feed_dict={'ids:0': [1,2,3]})
      sess.run([emb, loss], feed_dict={'ids:0': [1,3,5]})
      sess.run([emb, loss], feed_dict={'ids:0': [1,5,7]})
      s = sess.run(shape)
      # Inference-mode lookups do not create keys, so the dynamic shape
      # reports zero ids.
      self.assertAllEqual(np.array([0,3]), s)
    del os.environ["INFERENCE_MODE"]

  '''
  @test_util.run_gpu_only
  def testEmbeddingVariableForHBMandDRAM(self):