From de3f47585184995ba4f9866227858d55ac2a1947 Mon Sep 17 00:00:00 2001 From: "candy.dc" Date: Wed, 7 Sep 2022 17:34:03 +0800 Subject: [PATCH] [Embedding] Support immutable EmbeddingVariable in inference mode. --- .../framework/embedding/embedding_filter.h | 13 +++++++++ .../core/framework/embedding/embedding_var.h | 9 ++++++ .../embedding/multilevel_embedding.h | 12 ++++++++ tensorflow/core/kernels/kv_variable_ops.cc | 21 +++++++++++++- tensorflow/core/ops/kv_variable_ops.cc | 2 ++ .../python/ops/embedding_variable_ops_test.py | 29 +++++++++++++++++++ 6 files changed, 85 insertions(+), 1 deletion(-) diff --git a/tensorflow/core/framework/embedding/embedding_filter.h b/tensorflow/core/framework/embedding/embedding_filter.h index 3767c8af853..267b12a214e 100644 --- a/tensorflow/core/framework/embedding/embedding_filter.h +++ b/tensorflow/core/framework/embedding/embedding_filter.h @@ -50,6 +50,19 @@ class EmbeddingFilter { public: virtual void LookupOrCreate(K key, V* val, const V* default_value_ptr, ValuePtr** value_ptr, int count, const V* default_value_no_permission) = 0; + + virtual void Lookup(EV* ev, K key, V* val, const V* default_value_ptr, + const V* default_value_no_permission) { /* Copies ev->ValueLen() values for `key` into `val`: the primary embedding when ev->LookupKey succeeds, otherwise default_value_no_permission. default_value_ptr is not read by this implementation. */ + ValuePtr* value_ptr = nullptr; + Status s = ev->LookupKey(key, &value_ptr); + if (s.ok()) { + V* mem_val = ev->LookupPrimaryEmb(value_ptr); + memcpy(val, mem_val, sizeof(V) * ev->ValueLen()); + } else { + memcpy(val, default_value_no_permission, sizeof(V) * ev->ValueLen()); + } + } + virtual Status LookupOrCreateKey(K key, ValuePtr** val, bool* is_filter) = 0; virtual void CreateGPUBatch(V* val_base, V** default_values, int64 size, int64 slice_elems, int64 value_len_, bool* init_flags, V** memcpy_address) = 0; diff --git a/tensorflow/core/framework/embedding/embedding_var.h b/tensorflow/core/framework/embedding/embedding_var.h index 74f6e70be84..149bc3b7484 100644 --- a/tensorflow/core/framework/embedding/embedding_var.h +++ b/tensorflow/core/framework/embedding/embedding_var.h 
@@ -121,6 +121,10 @@ class EmbeddingVar : public ResourceBase { return is_initialized_; } + Status LookupKey(K key, ValuePtr** value_ptr) { /* forwards to storage_manager_->Get and returns its status */ + return storage_manager_->Get(key, value_ptr); + } + Status LookupOrCreateKey(K key, ValuePtr** value_ptr, bool* is_filter) { return filter_->LookupOrCreateKey(key, value_ptr, is_filter); } @@ -159,6 +163,11 @@ class EmbeddingVar : public ResourceBase { return filter_->GetFreq(key); } + void Lookup(K key, V* val, V* default_v) { /* forwards to filter_->Lookup; a null default_v selects default_value_ */ + const V* default_value_ptr = (default_v == nullptr) ? default_value_ : default_v; + filter_->Lookup(this, key, val, default_value_ptr, default_value_no_permission_); + } + void LookupOrCreate(K key, V* val, V* default_v, int count = 1) { const V* default_value_ptr = (default_v == nullptr) ? default_value_ : default_v; ValuePtr* value_ptr = nullptr; diff --git a/tensorflow/core/framework/embedding/multilevel_embedding.h b/tensorflow/core/framework/embedding/multilevel_embedding.h index 4588161c923..a6868053911 100644 --- a/tensorflow/core/framework/embedding/multilevel_embedding.h +++ b/tensorflow/core/framework/embedding/multilevel_embedding.h @@ -243,6 +243,18 @@ class StorageManager { } } + Status Get(K key, ValuePtr** value_ptr) { /* tries each storage level in index order; returns the first OK status, otherwise the status of the last level tried */ + Status s; + int level = 0; + for (; level < hash_table_count_; ++level) { + s = kvs_[level].first->Lookup(key, value_ptr); + if (s.ok()) { + break; + } + } + return s; + } + Status GetOrCreate(K key, ValuePtr** value_ptr, size_t size) { bool found = false; int level = 0; diff --git a/tensorflow/core/kernels/kv_variable_ops.cc b/tensorflow/core/kernels/kv_variable_ops.cc index 381ba891cd9..df50e7c1a4c 100644 --- a/tensorflow/core/kernels/kv_variable_ops.cc +++ b/tensorflow/core/kernels/kv_variable_ops.cc @@ -50,6 +50,7 @@ namespace tensorflow { namespace { const int64 kEmbeddingVarUseDB = -214; const int64 kInitializableEmbeddingVarUseDB = -215; +const char* const kInferenceMode = "INFERENCE_MODE"; /* env var name; the pointer itself is const so the constant cannot be reseated */ } #define REGISTER_KV_VAR_HANDLE(ktype, vtype) \ @@ -438,6 +439,10 @@ template 
class KvResourceGatherOp : public OpKernel { public: explicit KvResourceGatherOp(OpKernelConstruction* c) : OpKernel(c) { + OP_REQUIRES_OK(c, c->GetAttr("is_inference", &is_inference_)); + bool is_inference = false; + OP_REQUIRES_OK(c, ReadBoolFromEnvVar(kInferenceMode, false, &is_inference)); /* an unparsable env value fails kernel construction with a Status instead of aborting the process */ + is_inference_ |= is_inference; /* either the attr or the env var enables inference mode */ OP_REQUIRES_OK(c, c->GetAttr("is_use_default_value_tensor", &is_use_default_value_tensor_)); @@ -461,6 +466,17 @@ class KvResourceGatherOp : public OpKernel { return 1; }; } + if (!is_inference_) { + lookup_fn_ = [](EmbeddingVar* ev, TKey key, + TValue* val, TValue* default_v, int count) { + ev->LookupOrCreate(key, val, default_v, count); + }; + } else { + lookup_fn_ = [](EmbeddingVar* ev, TKey key, + TValue* val, TValue* default_v, int count) { + ev->Lookup(key, val, default_v); /* count is not used on the inference path */ + }; + } } void Compute(OpKernelContext* c) override { @@ -511,7 +527,7 @@ class KvResourceGatherOp : public OpKernel { default_v, indices_flat(i), i, ev->GetDefaultValueDim(), ev->ValueLen()); int32 count = get_count_fn_(counts, i); - ev->LookupOrCreate(indices_flat(i), + lookup_fn_(ev, indices_flat(i), out_base + i * slice_elems, default_v_ptr, count); } }; @@ -530,9 +546,12 @@ class KvResourceGatherOp : public OpKernel { private: bool is_use_default_value_tensor_; + bool is_inference_; std::function< TValue*(TValue*, TKey, int64, int64, int64)> get_default_v_fn_; std::function get_count_fn_; + std::function* ev, + TKey key, TValue* val, TValue* default_v, int count)> lookup_fn_; }; #define REGISTER_GATHER_FULL(dev, ktype, vtype) \ diff --git a/tensorflow/core/ops/kv_variable_ops.cc b/tensorflow/core/ops/kv_variable_ops.cc index 6cc5a6974d3..234f9c91978 100644 --- a/tensorflow/core/ops/kv_variable_ops.cc +++ b/tensorflow/core/ops/kv_variable_ops.cc @@ -234,6 +234,7 @@ REGISTER_OP("KvResourceGatherV1") .Input("counts: counts_type") .Attr("validate_indices: bool = true") .Attr("is_use_default_value_tensor: bool = false") + .Attr("is_inference: bool = false") .Output("output: 
dtype") .Attr("dtype: type") .Attr("Tkeys: {int64,int32,string}") @@ -284,6 +285,7 @@ REGISTER_OP("KvResourceGather") .Output("output: dtype") .Attr("dtype: type") .Attr("Tkeys: {int64,int32,string}") + .Attr("is_inference: bool = false") .SetShapeFn([](InferenceContext* c) { ShapeAndType handle_shape_and_type; TF_RETURN_IF_ERROR( diff --git a/tensorflow/python/ops/embedding_variable_ops_test.py b/tensorflow/python/ops/embedding_variable_ops_test.py index a5aaf07ec04..f80f9f477a2 100644 --- a/tensorflow/python/ops/embedding_variable_ops_test.py +++ b/tensorflow/python/ops/embedding_variable_ops_test.py @@ -16,6 +16,7 @@ from six.moves import xrange # pylint: disable=redefined-builtin +from tensorflow.core.framework import attr_value_pb2 from tensorflow.python.framework import ops from tensorflow.python.framework import test_util from tensorflow.python.ops import string_ops @@ -2236,6 +2237,34 @@ def testEmbeddingVariableForGetFrequencyAndVersion(self): self.assertAllEqual(np.array([3,1,2,0,2,0,1]), f) self.assertAllEqual(np.array([2,0,1,0,2,0,2]), v) + def testEmbeddingVariableForInference(self): + print("testEmbeddingVariableForInference") + var = variable_scope.get_embedding_variable("var_1", + embedding_dim = 3, + initializer=init_ops.ones_initializer(dtypes.float32), + ev_option = variables.EmbeddingVariableOption( + filter_option=variables.CounterFilter(filter_freq=3), + evict_option=variables.GlobalStepEvict(steps_to_live=2)) + ) + shape=var.get_dynamic_shape() + ids = array_ops.placeholder(dtype=dtypes.int64, name='ids') + emb = embedding_ops.embedding_lookup(var, ids) + # modify graph for infer + # emb.op.inputs[0].op.inputs[0].op._set_attr("is_inference", attr_value_pb2.AttrValue(b=True)) + self.addCleanup(os.environ.pop, "INFERENCE_MODE", None) # set environment; the cleanup unsets it even when the test fails + os.environ["INFERENCE_MODE"] = "1" + fun = math_ops.multiply(emb, 2.0, name='multiply') + loss = math_ops.reduce_sum(fun, name='reduce_sum') + init = variables.global_variables_initializer() + with self.test_session() as sess: + sess.run([init]) + 
sess.run([emb, loss], feed_dict={'ids:0': [1,2,3]}) + sess.run([emb, loss], feed_dict={'ids:0': [1,3,5]}) + sess.run([emb, loss], feed_dict={'ids:0': [1,5,7]}) + s = sess.run(shape) + self.assertAllEqual(np.array([0,3]), s) + del os.environ["INFERENCE_MODE"] + ''' @test_util.run_gpu_only def testEmbeddingVariableForHBMandDRAM(self):