From 261fff98e511b8dbe961bcde86144c8d9eeab547 Mon Sep 17 00:00:00 2001 From: ZhouSiLe Date: Sun, 21 Jul 2024 22:48:24 +0800 Subject: [PATCH 01/13] feat: support some Eval script flags --- src/cluster/cluster.cc | 26 ++- src/cluster/cluster.h | 3 +- src/server/worker.cc | 2 +- src/storage/scripting.cc | 104 +++++----- src/storage/scripting.h | 138 +++++++++++++- tests/gocase/unit/scripting/scripting_test.go | 180 ++++++++++++++++++ 6 files changed, 403 insertions(+), 50 deletions(-) diff --git a/src/cluster/cluster.cc b/src/cluster/cluster.cc index a38c086897d..cccb1758521 100644 --- a/src/cluster/cluster.cc +++ b/src/cluster/cluster.cc @@ -824,9 +824,9 @@ bool Cluster::IsWriteForbiddenSlot(int slot) const { } Status Cluster::CanExecByMySelf(const redis::CommandAttributes *attributes, const std::vector &cmd_tokens, - redis::Connection *conn) { + redis::Connection *conn, lua::ScriptRunCtx *script_run_ctx) { std::vector keys_indexes; - + std::cout << "CanExecByMySelf\n"; // No keys if (auto s = redis::CommandTable::GetKeysFromCommand(attributes, cmd_tokens, &keys_indexes); !s.IsOK()) return Status::OK(); @@ -849,6 +849,22 @@ Status Cluster::CanExecByMySelf(const redis::CommandAttributes *attributes, cons return {Status::RedisClusterDown, "Hash slot not served"}; } + bool cross_slot_ok = false; + if (script_run_ctx) { + std::cout << "Check script_run_ctx\n"; + if(script_run_ctx->current_slot != -1 && script_run_ctx->current_slot != slot) { + if(getNodeIDBySlot(script_run_ctx->current_slot) != getNodeIDBySlot(slot)) { + return {Status::RedisMoved, fmt::format("{} {}:{}", slot, slots_nodes_[slot]->host, slots_nodes_[slot]->port)}; + } + if(!(script_run_ctx->flags & lua::ScriptFlags::kScriptAllowCrossSlotKeys)) { + return {Status::RedisCrossSlot, "Script attempted to access keys that do not hash to the same slot"}; + } + } + + script_run_ctx->current_slot = slot; + cross_slot_ok = true; + } + if (myself_ && myself_ == slots_nodes_[slot]) { // We use central controller to manage the topology of the cluster. // Server can't change the topology directly, so we record the migrated slots @@ -886,8 +902,12 @@ Status Cluster::CanExecByMySelf(const redis::CommandAttributes *attributes, cons conn->IsFlagEnabled(redis::Connection::kReadOnly)) { return Status::OK(); // My master is serving this slot } + + if(!cross_slot_ok) { + return {Status::RedisMoved, fmt::format("{} {}:{}", slot, slots_nodes_[slot]->host, slots_nodes_[slot]->port)}; + } - return {Status::RedisMoved, fmt::format("{} {}:{}", slot, slots_nodes_[slot]->host, slots_nodes_[slot]->port)}; + return Status::OK(); } // Only HARD mode is meaningful to the Kvrocks cluster, diff --git a/src/cluster/cluster.h b/src/cluster/cluster.h index e595666c75e..468c154d4d8 100644 --- a/src/cluster/cluster.h +++ b/src/cluster/cluster.h @@ -35,6 +35,7 @@ #include "redis_slot.h" #include "server/redis_connection.h" #include "status.h" +#include "storage/scripting.h" class ClusterNode { public: @@ -83,7 +84,7 @@ class Cluster { bool IsNotMaster(); bool IsWriteForbiddenSlot(int slot) const; Status CanExecByMySelf(const redis::CommandAttributes *attributes, const std::vector &cmd_tokens, - redis::Connection *conn); + redis::Connection *conn, lua::ScriptRunCtx *script_run_ctx = nullptr); Status SetMasterSlaveRepl(); Status MigrateSlotRange(const SlotRange &slot_range, const std::string &dst_node_id, SyncMigrateContext *blocking_ctx = nullptr); diff --git a/src/server/worker.cc b/src/server/worker.cc index 22054e1faf8..0b37bcef30d 100644 --- a/src/server/worker.cc +++ b/src/server/worker.cc @@ -72,7 +72,7 @@ Worker::Worker(Server *srv, Config *config) : srv(srv), base_(event_base_new()) LOG(INFO) << "[worker] Listening on: " << bind << ":" << *port; } } - lua_ = lua::CreateState(srv, true); + lua_ = lua::CreateState(srv); } Worker::~Worker() { diff --git a/src/storage/scripting.cc b/src/storage/scripting.cc index 987f9f0dfc0..a5e53a3e8bc 100644 --- a/src/storage/scripting.cc +++ b/src/storage/scripting.cc @@ -57,11 +57,11 @@ enum { namespace lua { -lua_State *CreateState(Server *srv, bool read_only) { +lua_State *CreateState(Server *srv) { lua_State *lua = lua_open(); LoadLibraries(lua); RemoveUnsupportedFunctions(lua); - LoadFuncs(lua, read_only); + LoadFuncs(lua); lua_pushlightuserdata(lua, srv); lua_setglobal(lua, REDIS_LUA_SERVER_PTR); @@ -75,7 +75,7 @@ void DestroyState(lua_State *lua) { lua_close(lua); } -void LoadFuncs(lua_State *lua, bool read_only) { +void LoadFuncs(lua_State *lua) { lua_newtable(lua); /* redis.call */ @@ -127,11 +127,6 @@ void LoadFuncs(lua_State *lua, bool read_only) { lua_pushcfunction(lua, RedisStatusReplyCommand); lua_settable(lua, -3); - /* redis.read_only */ - lua_pushstring(lua, "read_only"); - lua_pushboolean(lua, read_only); - lua_settable(lua, -3); - /* redis.register_function */ lua_pushstring(lua, "register_function"); lua_pushcfunction(lua, RedisRegisterFunction); @@ -182,6 +177,15 @@ void LoadFuncs(lua_State *lua, bool read_only) { lua_pcall(lua, 0, 0, 0); } +void LoadScriptFlags(lua_State *lua, uint64_t flags) { + std::cout << "LoadScriptFlags:" << flags << '\n'; + lua_getglobal(lua, "redis"); + lua_pushstring(lua, "script_flags"); + lua_pushinteger(lua, static_cast(flags)); + lua_settable(lua, -3); + lua_pop(lua, 1); + stackDump(lua); +} int RedisLogCommand(lua_State *lua) { int argc = lua_gettop(lua); @@ -226,8 +230,8 @@ int RedisLogCommand(lua_State *lua) { int RedisRegisterFunction(lua_State *lua) { int argc = lua_gettop(lua); - if (argc != 2) { - lua_pushstring(lua, "redis.register_function() requires two arguments."); + if (argc < 2) { + lua_pushstring(lua, "redis.register_function() requires at least two arguments."); return lua_error(lua); } @@ -288,31 +292,10 @@ Status FunctionLoad(redis::Connection *conn, const std::string &script, bool nee return {Status::NotOK, "Expect a Shebang statement in the first line"}; } - static constexpr const char *shebang_prefix = "#!lua"; - static constexpr const char *shebang_libname_prefix = "name="; - - auto first_line_split = util::Split(first_line, " \r\t"); - if (first_line_split.empty() || first_line_split[0] != shebang_prefix) { - return {Status::NotOK, "Expect a Shebang statement in the first line, e.g. `#!lua name=mylib`"}; - } - - size_t libname_pos = 1; - for (; libname_pos < first_line_split.size(); ++libname_pos) { - if (util::HasPrefix(first_line_split[libname_pos], shebang_libname_prefix)) { - break; - } - } - - if (libname_pos >= first_line_split.size()) { - return {Status::NotOK, "Expect library name in the Shebang statement, e.g. `#!lua name=mylib`"}; - } + ShebangParser parser(first_line); + if (auto s = parser.Parse(); !s.IsOK()) return s; + auto libname = parser.GetLibName(); - auto libname = first_line_split[libname_pos].substr(strlen(shebang_libname_prefix)); - *lib_name = libname; - if (libname.empty() || - std::any_of(libname.begin(), libname.end(), [](char v) { return !std::isalnum(v) && v != '_'; })) { - return {Status::NotOK, "Expect a valid library name in the Shebang statement"}; - } auto srv = conn->GetServer(); auto lua = read_only ? conn->Owner()->Lua() : srv->Lua(); @@ -590,7 +573,6 @@ Status FunctionDelete(Server *srv, const std::string &name) { Status EvalGenericCommand(redis::Connection *conn, const std::string &body_or_sha, const std::vector &keys, const std::vector &argv, bool evalsha, std::string *output, bool read_only) { Server *srv = conn->GetServer(); - // Use the worker's private Lua VM when entering the read-only mode lua_State *lua = read_only ? conn->Owner()->Lua() : srv->Lua(); @@ -612,6 +594,8 @@ Status EvalGenericCommand(redis::Connection *conn, const std::string &body_or_sh /* Try to lookup the Lua function */ lua_getglobal(lua, funcname); + std::cout << "Try to lookup the Lua function\n"; + stackDump(lua); if (lua_isnil(lua, -1)) { lua_pop(lua, 1); /* remove the nil from the stack */ std::string body; @@ -624,6 +608,31 @@ Status EvalGenericCommand(redis::Connection *conn, const std::string &body_or_sh } else { body = body_or_sha; } + std::cout << "Get Body:\n" << body; + uint64_t script_flags = read_only ? ScriptFlags::kScriptNoWrites : 0; + if (auto pos = body.find('\n'); pos != std::string::npos) { + auto first_line = body.substr(0, pos); + std::cout << "\nGet First Line:" << first_line << '\n'; + + if (util::HasPrefix(first_line, "#!lua")) { + ShebangParser parser(first_line); + auto s = parser.Parse(); + if (!s.IsOK()) { + lua_pop(lua, 1); /* remove the error handler from the stack. */ + return s; + } + script_flags |= parser.GetFlags(); + } else { + // scripts without #! can run commands that access keys belonging to different cluster hash slots, + // but ones with #! inherit the default flags, so they cannot. + script_flags |= ScriptFlags::kScriptAllowCrossSlotKeys; + } + } + + ScriptRunCtx script_run_ctx; + script_run_ctx.flags = script_flags; + SaveOnRegistry(lua, REGISTRY_SCRIPT_RUN_CTX_NAME, &script_run_ctx); + // LoadScriptFlags(lua, script_flags); std::string sha = funcname + 2; auto s = CreateFunction(srv, body, &sha, lua, false); @@ -645,8 +654,12 @@ Status EvalGenericCommand(redis::Connection *conn, const std::string &body_or_sh * EVAL received. */ SetGlobalArray(lua, "KEYS", keys); SetGlobalArray(lua, "ARGV", argv); - + // int errfunc_index = + std::cout << "Before EvalGenericCommand lua_pcall\n"; + stackDump(lua); if (lua_pcall(lua, 0, 1, -2)) { + std::cout << "After EvalGenericCommand lua_pcall\n"; + stackDump(lua); auto msg = fmt::format("running script (call to {}): {}", funcname, lua_tostring(lua, -1)); *output = redis::Error({Status::NotOK, msg}); lua_pop(lua, 2); @@ -701,11 +714,9 @@ Server *GetServer(lua_State *lua) { // TODO: we do not want to repeat same logic as Connection::ExecuteCommands, // so the function need to be refactored int RedisGenericCommand(lua_State *lua, int raise_error) { - lua_getglobal(lua, "redis"); - lua_getfield(lua, -1, "read_only"); - int read_only = lua_toboolean(lua, -1); - lua_pop(lua, 2); - + ScriptRunCtx *script_run_ctx = GetFromRegistry(lua, REGISTRY_SCRIPT_RUN_CTX_NAME); + std::cout << "get script_flags = " << script_run_ctx->flags << '\n'; + stackDump(lua); int argc = lua_gettop(lua); if (argc == 0) { PushError(lua, "Please specify at least one argument for redis.call()"); @@ -738,7 +749,7 @@ int RedisGenericCommand(lua_State *lua, int raise_error) { auto attributes = cmd->GetAttributes(); auto cmd_flags = attributes->GenerateFlags(args); - if (read_only && !(cmd_flags & redis::kCmdReadOnly)) { + if ((script_run_ctx->flags & ScriptFlags::kScriptNoWrites) && !(cmd_flags & redis::kCmdReadOnly)) { PushError(lua, "Write commands are not allowed from read-only scripts"); return raise_error ? RaiseError(lua) : 1; } @@ -760,8 +771,13 @@ int RedisGenericCommand(lua_State *lua, int raise_error) { redis::Connection *conn = srv->GetCurrentConnection(); if (config->cluster_enabled) { - auto s = srv->cluster->CanExecByMySelf(attributes, args, conn); + if (script_run_ctx->flags & ScriptFlags::kScriptNoCluster) { + PushError(lua, "Can not run script on cluster, 'no-cluster' flag is set"); + return raise_error ? RaiseError(lua) : 1; + } + auto s = srv->cluster->CanExecByMySelf(attributes, args, conn, script_run_ctx); if (!s.IsOK()) { + std::cout << "CanExecByMySelf failed, s = " << s.Msg() << '\n'; PushError(lua, redis::StatusToRedisErrorMsg(s).c_str()); return raise_error ? RaiseError(lua) : 1; } @@ -1340,6 +1356,8 @@ std::string ReplyToRedisReply(redis::Connection *conn, lua_State *lua) { [[noreturn]] int RaiseError(lua_State *lua) { lua_pushstring(lua, "err"); lua_gettable(lua, -2); + std::cout << "RaiseError\n"; + stackDump(lua); lua_error(lua); __builtin_unreachable(); } diff --git a/src/storage/scripting.h b/src/storage/scripting.h index 3b2dd45deef..08e2f387a25 100644 --- a/src/storage/scripting.h +++ b/src/storage/scripting.h @@ -36,11 +36,11 @@ inline constexpr const char REDIS_FUNCTION_LIBRARIES[] = "REDIS_FUNCTION_LIBRARI namespace lua { -lua_State *CreateState(Server *srv, bool read_only = false); +lua_State *CreateState(Server *srv); void DestroyState(lua_State *lua); Server *GetServer(lua_State *lua); -void LoadFuncs(lua_State *lua, bool read_only = false); +void LoadFuncs(lua_State *lua); void LoadLibraries(lua_State *lua); void RemoveUnsupportedFunctions(lua_State *lua); void EnableGlobalsProtection(lua_State *lua); @@ -101,4 +101,138 @@ void SHA1Hex(char *digest, const char *script, size_t len); int RedisMathRandom(lua_State *l); int RedisMathRandomSeed(lua_State *l); +// TODO 注释,默认值,uint=>int +enum ScriptFlags : uint64_t { + kScriptNoWrites = 1ULL << 0, // "no-writes" flag + kScriptAllowOom = 1ULL << 1, // "allow-oom" flag + kScriptAllowStale = 1ULL << 2, // "allow-stale" flag + kScriptNoCluster = 1ULL << 3, // "no-cluster" flag + kScriptAllowCrossSlotKeys = 1ULL << 4, // "allow-cross-slot-keys" flag +}; + +class ShebangParser { + public: + ShebangParser(const std::string &shebang) : shebang_(shebang) {} + + [[nodiscard]] Status Parse() { + std::cout << "Start Parse\n"; + static constexpr const char *shebang_prefix = "#!lua"; + static constexpr const char *shebang_libname_prefix = "name="; + static constexpr const char *shebang_flags_prefix = "flags="; + + if (!util::HasPrefix(shebang_, shebang_prefix)) { + return {Status::NotOK, "Expect shebang prefix \"#!lua\" at the beginning of the first line"}; + } + auto shebang_content = shebang_.substr(strlen(shebang_prefix)); + for (const auto &shebang_split : util::Split(shebang_content, " ")) { + std::cout << shebang_split << std::endl; + if (util::HasPrefix(shebang_split, shebang_libname_prefix)) { + if (!libname_.empty()) { + // TODO 已经有了 + return {Status::NotOK, "Expect a valid library name in the Shebang statement"}; + } + libname_ = shebang_split.substr(strlen(shebang_libname_prefix)); + if (libname_.empty() || + std::any_of(libname_.begin(), libname_.end(), [](char v) { return !std::isalnum(v) && v != '_'; })) { + return {Status::NotOK, "Expect a valid library name in the Shebang statement"}; + } + } else if (util::HasPrefix(shebang_split, shebang_flags_prefix)) { + auto flags = shebang_split.substr(strlen(shebang_flags_prefix)); + for (const auto &flag : util::Split(flags, ",")) { + if (flag == "no-writes") { + flags_ |= kScriptNoWrites; + } else if (flag == "allow-oom") { + return {Status::NotSupported, "allow-oom is not supported yet"}; + } else if (flag == "allow-stale") { + return {Status::NotSupported, "allow-stale is not supported yet"}; + } else if (flag == "no-cluster") { + flags_ |= kScriptNoCluster; + } else if (flag == "allow-cross-slot-keys") { + flags_ |= kScriptAllowCrossSlotKeys; + } else { + return {Status::NotOK, "Unexpected flag in script shebang: " + flag}; + } + } + } else { + return {Status::NotOK, "Expect a valid Shebang statement"}; + } + } + return Status::OK(); + } + + [[nodiscard]] uint64_t GetFlags() const { return flags_; } + [[nodiscard]] std::string GetLibName() const { return libname_; } + + private: + uint64_t flags_ = 0; + std::string libname_; + std::string shebang_; +}; + +inline constexpr const char *REGISTRY_SCRIPT_RUN_CTX_NAME = "SCRIPT_RUN_CTX"; +// TODO 注释 +struct ScriptRunCtx { + uint64_t flags = 0; + int current_slot = -1; +}; + +static void stackDump(lua_State *L) { + int top = lua_gettop(L); + for (auto i = top; i >= 1; i--) { /* repeat for each level */ + int t = lua_type(L, i); + printf("%d: ", i); + switch (t) { + case LUA_TSTRING: /* strings */ + printf("`%s'", lua_tostring(L, i)); + break; + + case LUA_TBOOLEAN: /* booleans */ + printf(lua_toboolean(L, i) ? "true" : "false"); + break; + + case LUA_TNUMBER: /* numbers */ + printf("%g", lua_tonumber(L, i)); + break; + default: /* other values */ + printf("%s", lua_typename(L, t)); + break; + } + printf("\n"); /* put a separator */ + } + printf("\n"); /* end the listing */ +} + +template +void SaveOnRegistry(lua_State *lua, const char *name, T *ptr) { + lua_pushstring(lua, name); + if (ptr) { + lua_pushlightuserdata(lua, ptr); + } else { + lua_pushnil(lua); + } + lua_settable(lua, LUA_REGISTRYINDEX); +} + +template +T *GetFromRegistry(lua_State *lua, const char *name) { + lua_pushstring(lua, name); + lua_gettable(lua, LUA_REGISTRYINDEX); + + if (lua_isnil(lua, -1)) { + lua_pop(lua, 1); /* pops the value */ + return nullptr; + } + + /* must be light user data */ + CHECK(lua_islightuserdata(lua, -1)); + auto *ptr = static_cast(lua_touserdata(lua, -1)); + + CHECK_NOTNULL(ptr); + + /* pops the value */ + lua_pop(lua, 1); + + return ptr; +} + } // namespace lua diff --git a/tests/gocase/unit/scripting/scripting_test.go b/tests/gocase/unit/scripting/scripting_test.go index 50680f7d63e..7ca1a4964f5 100644 --- a/tests/gocase/unit/scripting/scripting_test.go +++ b/tests/gocase/unit/scripting/scripting_test.go @@ -607,3 +607,183 @@ func TestScriptingWithRESP3(t *testing.T) { require.EqualValues(t, []interface{}{"f1", "v1"}, vals) }) } + +func TestEvalScriptFlags(t *testing.T) { + srv := util.StartServer(t, map[string]string{}) + defer srv.Close() + + ctx := context.Background() + rdb := srv.NewClient() + defer func() { require.NoError(t, rdb.Close()) }() + + t.Run("no-writes", func (t *testing.T) { + r := rdb.Do(ctx, "EVAL", + `#!lua flags=no-writes + return redis.call('set', 'k1','v1');`, "0") + util.ErrorRegexp(t, r.Err(), "ERR .* Write commands are not allowed from read-only scripts") + + r = rdb.Do(ctx, "EVAL", `return redis.call('set', 'k2','v2');`, "0") + require.NoError(t, r.Err()) + + r = rdb.Do(ctx, "EVAL", + `#!lua + return redis.call('set', 'k3','v3');`, "0") + require.NoError(t, r.Err()) + + r = rdb.Do(ctx, "EVAL_RO", + `return redis.call('set', 'k4','v4');`, "0") + util.ErrorRegexp(t, r.Err(), "ERR .* Write commands are not allowed from read-only scripts") + + r = rdb.Do(ctx, "EVAL_RO", + `#!lua + return redis.call('set', 'k5','v5');`, "0") + util.ErrorRegexp(t, r.Err(), "ERR .* Write commands are not allowed from read-only scripts") + + r = rdb.Do(ctx, "EVAL_RO", + `#!lua flags=no-writes + return redis.call('set', 'k6','v6');`, "0") + util.ErrorRegexp(t, r.Err(), "ERR .* Write commands are not allowed from read-only scripts") + }) + + + srv0 := util.StartServer(t, map[string]string{"cluster-enabled": "yes"}) + rdb0 := srv0.NewClient() + defer func() { require.NoError(t, rdb0.Close()) }() + defer func() { srv0.Close() }() + id0 := "xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx00" + require.NoError(t, rdb0.Do(ctx, "clusterx", "SETNODEID", id0).Err()) + + srv1 := util.StartServer(t, map[string]string{"cluster-enabled": "yes"}) + srv1Alive := true + defer func() { + if srv1Alive { + srv1.Close() + } + }() + + rdb1 := srv1.NewClient() + defer func() { require.NoError(t, rdb1.Close()) }() + id1 := "xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx01" + require.NoError(t, rdb1.Do(ctx, "clusterx", "SETNODEID", id1).Err()) + + clusterNodes := fmt.Sprintf("%s %s %d master - 0-10000\n", id0, srv0.Host(), srv0.Port()) + clusterNodes += fmt.Sprintf("%s %s %d master - 10001-16383", id1, srv1.Host(), srv1.Port()) + require.NoError(t, rdb0.Do(ctx, "clusterx", "SETNODES", clusterNodes, "1").Err()) + require.NoError(t, rdb1.Do(ctx, "clusterx", "SETNODES", clusterNodes, "1").Err()) + + t.Run("no-cluster", func (t *testing.T) { + r := rdb0.Do(ctx, "EVAL", + `#!lua flags=no-cluster + return redis.call('set', 'k','v');`, "0") + util.ErrorRegexp(t, r.Err(), "ERR .* Can not run script on cluster, 'no-cluster' flag is set") + + // Only valid in cluster mode + r = rdb.Do(ctx, "EVAL", + `#!lua flags=no-cluster + return redis.call('set', 'k','v');`, "0") + require.NoError(t, r.Err()) + + // Scripts without #! can run commands that access keys belonging to different cluster hash slots, + // but ones with #! inherit the default flags, so they cannot. + r = rdb0.Do(ctx, "EVAL", `return redis.call('set', 'k','v');`, "0") + require.NoError(t, r.Err()) + + r = rdb0.Do(ctx, "EVAL", + `#!lua + return redis.call('set', 'k','v');`, "0") + require.NoError(t, r.Err()) + }) + + t.Run("allow-cross-slot-keys", func (t *testing.T) { + // Node0: bar-slot = 5061, test-slot = 6918 + // Node1: foo-slot = 12182 + // Different slots of different nodes are not affected by allow-cross-slot-keys, + // and different slots of the same node can be allowed + r := rdb0.Do(ctx, "EVAL", + `#!lua flags=allow-cross-slot-keys + redis.call('set', 'bar','value_bar'); + return redis.call('set', 'test', 'value_test');`, "0"); + require.NoError(t, r.Err()) + + r = rdb0.Do(ctx, "EVAL", + `#!lua flags=allow-cross-slot-keys + redis.call('set', 'foo','value_foo'); + return redis.call('set', 'bar', 'value_bar');`, "0"); + util.ErrorRegexp(t, r.Err(), "ERR .* MOVED *") + + // There is a shebang prefix #!lua but crossslot is not allowed when flags are not set + r = rdb0.Do(ctx, "EVAL", + `#!lua + redis.call('get', 'bar'); + return redis.call('get', 'test');`, "0"); + util.ErrorRegexp(t, r.Err(), "ERR .* Script attempted to access keys that do not hash to the same slot") + + r = rdb0.Do(ctx, "EVAL", + `#!lua + redis.call('get', 'foo'); + return redis.call('get', 'bar');`, "0"); + util.ErrorRegexp(t, r.Err(), "ERR .* MOVED *") + + // Old style: CrossSlot is allowed when there is neither #!lua nor flags set + r = rdb0.Do(ctx, "EVAL", + `redis.call('get', 'bar'); + return redis.call('get', 'test');`, "0") + require.NoError(t, r.Err()) + + r = rdb0.Do(ctx, "EVAL", + `redis.call('get', 'foo'); + return redis.call('get', 'bar');`, "0") + util.ErrorRegexp(t, r.Err(), "ERR .* MOVED *") + + // Pre-declared keys are not affected by Arlo-Cross-Slot-Keyes + r = rdb0.Do(ctx, "EVAL", + `#!lua flags=allow-cross-slot-keys + local key = redis.call('get', KEY[1]); + return redis.call('get', KEY[2]);`, "2", "bar", "test"); + require.EqualError(t, r.Err(), "CROSSSLOT Attempted to access keys that don't hash to the same slot") + }) + + t.Run("invalid-flags", func (t *testing.T) { + r := rdb0.Do(ctx, "EVAL", + `#!lua flags=invalid-flag + return redis.call('set', 'k','v');`, "0") + util.ErrorRegexp(t, r.Err(), "ERR Unexpected flag in script shebang:*") + }) + + t.Run("mixed use", func (t *testing.T) { + r :=rdb0.Do(ctx, "EVAL", + `#!lua flags=no-writes,no-cluster + return redis.call('get', 'key_a');`, "0") + util.ErrorRegexp(t, r.Err(), "ERR .* Can not run script on cluster, 'no-cluster' flag is set") + + r = rdb.Do(ctx, "EVAL", + `#!lua flags=no-writes,no-cluster + return redis.call('set', 'key_a', 'value_a');`, "0") + util.ErrorRegexp(t, r.Err(), "ERR .* Write commands are not allowed from read-only scripts") + + err := rdb.Set(ctx, "key_a", "value_a", 0).Err() + require.NoError(t, err) + r = rdb.Do(ctx, "EVAL", + `#!lua flags=no-writes,no-cluster + return redis.call('get', 'key_a');`, "0") + require.NoError(t, r.Err()) + + r = rdb0.Do(ctx, "EVAL", + `#!lua flags=no-writes,allow-cross-slot-keys + redis.call('get', 'bar'); + return redis.call('get', 'test');`, "0") + require.NoError(t, r.Err()) + + r = rdb0.Do(ctx, "EVAL", + `#!lua flags=no-writes,allow-cross-slot-keys + redis.call('set', 'bar'); + return redis.call('set', 'test');`, "0") + util.ErrorRegexp(t, r.Err(), "ERR .* Write commands are not allowed from read-only scripts") + + r = rdb0.Do(ctx, "EVAL", + `#!lua flags=no-writes,allow-cross-slot-keys + redis.call('get', 'bar'); + return redis.call('get', 'foo');`, "0") + util.ErrorRegexp(t, r.Err(), "ERR .* MOVED *") + }) +} \ No newline at end of file From e1fefbc0231d8c3c1c689e4e8dc016754239a6cf Mon Sep 17 00:00:00 2001 From: ZhouSiLe Date: Sun, 21 Jul 2024 22:50:43 +0800 Subject: [PATCH 02/13] style: clang-format & gofmt --- src/cluster/cluster.cc | 10 +- src/storage/scripting.h | 4 +- tests/gocase/unit/scripting/scripting_test.go | 125 +++++++++--------- 3 files changed, 69 insertions(+), 70 deletions(-) diff --git a/src/cluster/cluster.cc b/src/cluster/cluster.cc index cccb1758521..73606833266 100644 --- a/src/cluster/cluster.cc +++ b/src/cluster/cluster.cc @@ -852,11 +852,11 @@ Status Cluster::CanExecByMySelf(const redis::CommandAttributes *attributes, cons bool cross_slot_ok = false; if (script_run_ctx) { std::cout << "Check script_run_ctx\n"; - if(script_run_ctx->current_slot != -1 && script_run_ctx->current_slot != slot) { - if(getNodeIDBySlot(script_run_ctx->current_slot) != getNodeIDBySlot(slot)) { + if (script_run_ctx->current_slot != -1 && script_run_ctx->current_slot != slot) { + if (getNodeIDBySlot(script_run_ctx->current_slot) != getNodeIDBySlot(slot)) { return {Status::RedisMoved, fmt::format("{} {}:{}", slot, slots_nodes_[slot]->host, slots_nodes_[slot]->port)}; } - if(!(script_run_ctx->flags & lua::ScriptFlags::kScriptAllowCrossSlotKeys)) { + if (!(script_run_ctx->flags & lua::ScriptFlags::kScriptAllowCrossSlotKeys)) { return {Status::RedisCrossSlot, "Script attempted to access keys that do not hash to the same slot"}; } } @@ -902,8 +902,8 @@ Status Cluster::CanExecByMySelf(const redis::CommandAttributes *attributes, cons conn->IsFlagEnabled(redis::Connection::kReadOnly)) { return Status::OK(); // My master is serving this slot } - - if(!cross_slot_ok) { + + if (!cross_slot_ok) { return {Status::RedisMoved, fmt::format("{} {}:{}", slot, slots_nodes_[slot]->host, slots_nodes_[slot]->port)}; } diff --git a/src/storage/scripting.h b/src/storage/scripting.h index 08e2f387a25..4feef5a4269 100644 --- a/src/storage/scripting.h +++ b/src/storage/scripting.h @@ -225,8 +225,8 @@ T *GetFromRegistry(lua_State *lua, const char *name) { /* must be light user data */ CHECK(lua_islightuserdata(lua, -1)); - auto *ptr = static_cast(lua_touserdata(lua, -1)); - + auto *ptr = static_cast(lua_touserdata(lua, -1)); + CHECK_NOTNULL(ptr); /* pops the value */ diff --git a/tests/gocase/unit/scripting/scripting_test.go b/tests/gocase/unit/scripting/scripting_test.go index 7ca1a4964f5..497834b0753 100644 --- a/tests/gocase/unit/scripting/scripting_test.go +++ b/tests/gocase/unit/scripting/scripting_test.go @@ -615,37 +615,36 @@ func TestEvalScriptFlags(t *testing.T) { ctx := context.Background() rdb := srv.NewClient() defer func() { require.NoError(t, rdb.Close()) }() - - t.Run("no-writes", func (t *testing.T) { - r := rdb.Do(ctx, "EVAL", - `#!lua flags=no-writes + + t.Run("no-writes", func(t *testing.T) { + r := rdb.Do(ctx, "EVAL", + `#!lua flags=no-writes return redis.call('set', 'k1','v1');`, "0") util.ErrorRegexp(t, r.Err(), "ERR .* Write commands are not allowed from read-only scripts") r = rdb.Do(ctx, "EVAL", `return redis.call('set', 'k2','v2');`, "0") require.NoError(t, r.Err()) - r = rdb.Do(ctx, "EVAL", - `#!lua + r = rdb.Do(ctx, "EVAL", + `#!lua return redis.call('set', 'k3','v3');`, "0") require.NoError(t, r.Err()) - r = rdb.Do(ctx, "EVAL_RO", - `return redis.call('set', 'k4','v4');`, "0") + r = rdb.Do(ctx, "EVAL_RO", + `return redis.call('set', 'k4','v4');`, "0") util.ErrorRegexp(t, r.Err(), "ERR .* Write commands are not allowed from read-only scripts") - r = rdb.Do(ctx, "EVAL_RO", - `#!lua + r = rdb.Do(ctx, "EVAL_RO", + `#!lua return redis.call('set', 'k5','v5');`, "0") util.ErrorRegexp(t, r.Err(), "ERR .* Write commands are not allowed from read-only scripts") - r = rdb.Do(ctx, "EVAL_RO", - `#!lua flags=no-writes + r = rdb.Do(ctx, "EVAL_RO", + `#!lua flags=no-writes return redis.call('set', 'k6','v6');`, "0") util.ErrorRegexp(t, r.Err(), "ERR .* Write commands are not allowed from read-only scripts") }) - srv0 := util.StartServer(t, map[string]string{"cluster-enabled": "yes"}) rdb0 := srv0.NewClient() defer func() { require.NoError(t, rdb0.Close()) }() @@ -671,119 +670,119 @@ func TestEvalScriptFlags(t *testing.T) { require.NoError(t, rdb0.Do(ctx, "clusterx", "SETNODES", clusterNodes, "1").Err()) require.NoError(t, rdb1.Do(ctx, "clusterx", "SETNODES", clusterNodes, "1").Err()) - t.Run("no-cluster", func (t *testing.T) { - r := rdb0.Do(ctx, "EVAL", - `#!lua flags=no-cluster + t.Run("no-cluster", func(t *testing.T) { + r := rdb0.Do(ctx, "EVAL", + `#!lua flags=no-cluster return redis.call('set', 'k','v');`, "0") util.ErrorRegexp(t, r.Err(), "ERR .* Can not run script on cluster, 'no-cluster' flag is set") - + // Only valid in cluster mode - r = rdb.Do(ctx, "EVAL", - `#!lua flags=no-cluster + r = rdb.Do(ctx, "EVAL", + `#!lua flags=no-cluster return redis.call('set', 'k','v');`, "0") require.NoError(t, r.Err()) - // Scripts without #! can run commands that access keys belonging to different cluster hash slots, + // Scripts without #! can run commands that access keys belonging to different cluster hash slots, // but ones with #! inherit the default flags, so they cannot. r = rdb0.Do(ctx, "EVAL", `return redis.call('set', 'k','v');`, "0") require.NoError(t, r.Err()) - r = rdb0.Do(ctx, "EVAL", - `#!lua + r = rdb0.Do(ctx, "EVAL", + `#!lua return redis.call('set', 'k','v');`, "0") require.NoError(t, r.Err()) }) - t.Run("allow-cross-slot-keys", func (t *testing.T) { + t.Run("allow-cross-slot-keys", func(t *testing.T) { // Node0: bar-slot = 5061, test-slot = 6918 // Node1: foo-slot = 12182 - // Different slots of different nodes are not affected by allow-cross-slot-keys, + // Different slots of different nodes are not affected by allow-cross-slot-keys, // and different slots of the same node can be allowed - r := rdb0.Do(ctx, "EVAL", - `#!lua flags=allow-cross-slot-keys + r := rdb0.Do(ctx, "EVAL", + `#!lua flags=allow-cross-slot-keys redis.call('set', 'bar','value_bar'); - return redis.call('set', 'test', 'value_test');`, "0"); + return redis.call('set', 'test', 'value_test');`, "0") require.NoError(t, r.Err()) - r = rdb0.Do(ctx, "EVAL", - `#!lua flags=allow-cross-slot-keys + r = rdb0.Do(ctx, "EVAL", + `#!lua flags=allow-cross-slot-keys redis.call('set', 'foo','value_foo'); - return redis.call('set', 'bar', 'value_bar');`, "0"); + return redis.call('set', 'bar', 'value_bar');`, "0") util.ErrorRegexp(t, r.Err(), "ERR .* MOVED *") // There is a shebang prefix #!lua but crossslot is not allowed when flags are not set - r = rdb0.Do(ctx, "EVAL", - `#!lua + r = rdb0.Do(ctx, "EVAL", + `#!lua redis.call('get', 'bar'); - return redis.call('get', 'test');`, "0"); + return redis.call('get', 'test');`, "0") util.ErrorRegexp(t, r.Err(), "ERR .* Script attempted to access keys that do not hash to the same slot") - r = rdb0.Do(ctx, "EVAL", - `#!lua + r = rdb0.Do(ctx, "EVAL", + `#!lua redis.call('get', 'foo'); - return redis.call('get', 'bar');`, "0"); + return redis.call('get', 'bar');`, "0") util.ErrorRegexp(t, r.Err(), "ERR .* MOVED *") - + // Old style: CrossSlot is allowed when there is neither #!lua nor flags set - r = rdb0.Do(ctx, "EVAL", - `redis.call('get', 'bar'); + r = rdb0.Do(ctx, "EVAL", + `redis.call('get', 'bar'); return redis.call('get', 'test');`, "0") require.NoError(t, r.Err()) - - r = rdb0.Do(ctx, "EVAL", - `redis.call('get', 'foo'); + + r = rdb0.Do(ctx, "EVAL", + `redis.call('get', 'foo'); return redis.call('get', 'bar');`, "0") util.ErrorRegexp(t, r.Err(), "ERR .* MOVED *") // Pre-declared keys are not affected by Arlo-Cross-Slot-Keyes - r = rdb0.Do(ctx, "EVAL", - `#!lua flags=allow-cross-slot-keys + r = rdb0.Do(ctx, "EVAL", + `#!lua flags=allow-cross-slot-keys local key = redis.call('get', KEY[1]); - return redis.call('get', KEY[2]);`, "2", "bar", "test"); + return redis.call('get', KEY[2]);`, "2", "bar", "test") require.EqualError(t, r.Err(), "CROSSSLOT Attempted to access keys that don't hash to the same slot") }) - t.Run("invalid-flags", func (t *testing.T) { - r := rdb0.Do(ctx, "EVAL", - `#!lua flags=invalid-flag + t.Run("invalid-flags", func(t *testing.T) { + r := rdb0.Do(ctx, "EVAL", + `#!lua flags=invalid-flag return redis.call('set', 'k','v');`, "0") util.ErrorRegexp(t, r.Err(), "ERR Unexpected flag in script shebang:*") }) - t.Run("mixed use", func (t *testing.T) { - r :=rdb0.Do(ctx, "EVAL", - `#!lua flags=no-writes,no-cluster + t.Run("mixed use", func(t *testing.T) { + r := rdb0.Do(ctx, "EVAL", + `#!lua flags=no-writes,no-cluster return redis.call('get', 'key_a');`, "0") util.ErrorRegexp(t, r.Err(), "ERR .* Can not run script on cluster, 'no-cluster' flag is set") - - r = rdb.Do(ctx, "EVAL", - `#!lua flags=no-writes,no-cluster + + r = rdb.Do(ctx, "EVAL", + `#!lua flags=no-writes,no-cluster return redis.call('set', 'key_a', 'value_a');`, "0") util.ErrorRegexp(t, r.Err(), "ERR .* Write commands are not allowed from read-only scripts") err := rdb.Set(ctx, "key_a", "value_a", 0).Err() require.NoError(t, err) - r = rdb.Do(ctx, "EVAL", - `#!lua flags=no-writes,no-cluster + r = rdb.Do(ctx, "EVAL", + `#!lua flags=no-writes,no-cluster return redis.call('get', 'key_a');`, "0") require.NoError(t, r.Err()) - r = rdb0.Do(ctx, "EVAL", - `#!lua flags=no-writes,allow-cross-slot-keys + r = rdb0.Do(ctx, "EVAL", + `#!lua flags=no-writes,allow-cross-slot-keys redis.call('get', 'bar'); return redis.call('get', 'test');`, "0") require.NoError(t, r.Err()) - r = rdb0.Do(ctx, "EVAL", - `#!lua flags=no-writes,allow-cross-slot-keys + r = rdb0.Do(ctx, "EVAL", + `#!lua flags=no-writes,allow-cross-slot-keys redis.call('set', 'bar'); return redis.call('set', 'test');`, "0") util.ErrorRegexp(t, r.Err(), "ERR .* Write commands are not allowed from read-only scripts") - r = rdb0.Do(ctx, "EVAL", - `#!lua flags=no-writes,allow-cross-slot-keys + r = rdb0.Do(ctx, "EVAL", + `#!lua flags=no-writes,allow-cross-slot-keys redis.call('get', 'bar'); return redis.call('get', 'foo');`, "0") util.ErrorRegexp(t, r.Err(), "ERR .* MOVED *") }) -} \ No newline at end of file +} From 8f290f20a24b6631705883536d1b9acc329cf8d8 Mon Sep 17 00:00:00 2001 From: ZhouSiLe Date: Thu, 25 Jul 2024 14:57:04 +0800 Subject: [PATCH 03/13] feat: support function script flags --- src/cluster/cluster.cc | 2 - src/server/server.cc | 11 + src/server/server.h | 4 + src/storage/scripting.cc | 171 +++++++++----- src/storage/scripting.h | 112 ++------- tests/gocase/unit/scripting/function_test.go | 217 +++++++++++++++++- tests/gocase/unit/scripting/scripting_test.go | 12 +- 7 files changed, 376 insertions(+), 153 deletions(-) diff --git a/src/cluster/cluster.cc b/src/cluster/cluster.cc index 73606833266..5ce23954b20 100644 --- a/src/cluster/cluster.cc +++ b/src/cluster/cluster.cc @@ -826,7 +826,6 @@ bool Cluster::IsWriteForbiddenSlot(int slot) const { Status Cluster::CanExecByMySelf(const redis::CommandAttributes *attributes, const std::vector &cmd_tokens, redis::Connection *conn, lua::ScriptRunCtx *script_run_ctx) { std::vector keys_indexes; - std::cout << "CanExecByMySelf\n"; // No keys if (auto s = redis::CommandTable::GetKeysFromCommand(attributes, cmd_tokens, &keys_indexes); !s.IsOK()) return Status::OK(); @@ -851,7 +850,6 @@ Status Cluster::CanExecByMySelf(const redis::CommandAttributes *attributes, cons bool cross_slot_ok = false; if (script_run_ctx) { - std::cout << "Check script_run_ctx\n"; if (script_run_ctx->current_slot != -1 && script_run_ctx->current_slot != slot) { if (getNodeIDBySlot(script_run_ctx->current_slot) != getNodeIDBySlot(slot)) { return {Status::RedisMoved, fmt::format("{} {}:{}", slot, slots_nodes_[slot]->host, slots_nodes_[slot]->port)}; diff --git a/src/server/server.cc b/src/server/server.cc index 5e3b7f1363c..2b6bb581861 100644 --- a/src/server/server.cc +++ b/src/server/server.cc @@ -1682,6 +1682,16 @@ Status Server::ScriptSet(const std::string &sha, const std::string &body) const return storage->WriteToPropagateCF(func_name, body); } +void Server::CacheScriptFlags(const std::string &sha, uint64_t flags) { script_flags_cache_.try_emplace(sha, flags); } + +[[nodiscard]] Status Server::GetScriptFlags(const std::string &sha, uint64_t &flags) const { + if (script_flags_cache_.count(sha)) { + flags = script_flags_cache_.at(sha); + return Status::OK(); + } + return {Status::NotFound, "The flags cache of script sha does not exist, sha: " + sha}; +} + Status Server::FunctionGetCode(const std::string &lib, std::string *code) const { std::string func_name = engine::kLuaLibCodePrefix + lib; auto cf = storage->GetCFHandle(ColumnFamilyID::Propagate); @@ -1714,6 +1724,7 @@ Status Server::FunctionSetLib(const std::string &func, const std::string &lib) c void Server::ScriptReset() { auto lua = lua_.exchange(lua::CreateState(this)); + script_flags_cache_.clear(); lua::DestroyState(lua); } diff --git a/src/server/server.h b/src/server/server.h index c1793e81ce6..1b9b10a6464 100644 --- a/src/server/server.h +++ b/src/server/server.h @@ -274,6 +274,8 @@ class Server { Status ScriptSet(const std::string &sha, const std::string &body) const; void ScriptReset(); Status ScriptFlush(); + void CacheScriptFlags(const std::string &sha, uint64_t flags); + [[nodiscard]] Status GetScriptFlags(const std::string &sha, uint64_t &flags) const; Status FunctionGetCode(const std::string &lib, std::string *code) const; Status FunctionGetLib(const std::string &func, std::string *lib) const; @@ -341,6 +343,8 @@ class Server { std::mutex last_random_key_cursor_mu_; std::atomic lua_; + // The cache of flag is cached when the script is created and cleared when the script is flushed. + std::unordered_map script_flags_cache_; redis::Connection *curr_connection_ = nullptr; diff --git a/src/storage/scripting.cc b/src/storage/scripting.cc index a5e53a3e8bc..1c1b4f74aec 100644 --- a/src/storage/scripting.cc +++ b/src/storage/scripting.cc @@ -177,15 +177,6 @@ void LoadFuncs(lua_State *lua) { lua_pcall(lua, 0, 0, 0); } -void LoadScriptFlags(lua_State *lua, uint64_t flags) { - std::cout << "LoadScriptFlags:" << flags << '\n'; - lua_getglobal(lua, "redis"); - lua_pushstring(lua, "script_flags"); - lua_pushinteger(lua, static_cast(flags)); - lua_settable(lua, -3); - lua_pop(lua, 1); - stackDump(lua); -} int RedisLogCommand(lua_State *lua) { int argc = lua_gettop(lua); @@ -247,6 +238,9 @@ int RedisRegisterFunction(lua_State *lua) { // set this function to global std::string name = lua_tostring(lua, 1); + if (argc == 3) { + lua_setglobal(lua, (REDIS_LUA_REGISTER_FUNC_FLAGS_PREFIX + name).c_str()); + } lua_setglobal(lua, (REDIS_LUA_REGISTER_FUNC_PREFIX + name).c_str()); // set this function name to REDIS_FUNCTION_LIBRARIES[libname] @@ -278,7 +272,6 @@ int RedisRegisterFunction(lua_State *lua) { lua_pushstring(lua, "redis.register_function() failed to store informantion."); return lua_error(lua); } - return 0; } @@ -292,9 +285,8 @@ Status FunctionLoad(redis::Connection *conn, const std::string &script, bool nee return {Status::NotOK, "Expect a Shebang statement in the first line"}; } - ShebangParser parser(first_line); - if (auto s = parser.Parse(); !s.IsOK()) return s; - auto libname = parser.GetLibName(); + std::string libname; + if (auto s = ExtractLibNameFromShebang(first_line, libname); !s.IsOK()) return s; auto srv = conn->GetServer(); auto lua = read_only ? conn->Owner()->Lua() : srv->Lua(); @@ -392,13 +384,43 @@ Status FunctionCall(redis::Connection *conn, const std::string &name, const std: std::string libcode; s = srv->FunctionGetCode(libname, &libcode); if (!s) return s; - s = FunctionLoad(conn, libcode, false, false, &libname, read_only); if (!s) return s; lua_getglobal(lua, (REDIS_LUA_REGISTER_FUNC_PREFIX + name).c_str()); } + uint64_t function_flags = read_only ? ScriptFlags::kScriptNoWrites : 0; + lua_getglobal(lua, (REDIS_LUA_REGISTER_FUNC_FLAGS_PREFIX + name).c_str()); + // script, myfunc, err_func, table + if (!lua_isnil(lua, -1)) { + int n = static_cast(lua_objlen(lua, -1)); + for (int i = 1; i <= n; ++i) { + lua_pushnumber(lua, i); + lua_gettable(lua, -2); + std::string flag = lua_tostring(lua, -1); + if (flag == "no-writes") { + function_flags |= kScriptNoWrites; + } else if (flag == "allow-oom") { + return {Status::NotSupported, "allow-oom is not supported yet"}; + } else if (flag == "allow-stale") { + return {Status::NotSupported, "allow-stale is not supported yet"}; + } else if (flag == "no-cluster") { + function_flags |= kScriptNoCluster; + } else if (flag == "allow-cross-slot-keys") { + function_flags |= kScriptAllowCrossSlotKeys; + } else { + return {Status::NotOK, "Unexpected function flag: " + flag}; + } + lua_pop(lua, 1); + } + } + lua_pop(lua, 1); + + ScriptRunCtx script_run_ctx; + script_run_ctx.flags = function_flags; + SaveOnRegistry(lua, REGISTRY_SCRIPT_RUN_CTX_NAME, &script_run_ctx); + PushArray(lua, keys); PushArray(lua, argv); if (lua_pcall(lua, 2, 1, -4)) { @@ -555,6 +577,8 @@ Status FunctionDelete(Server *srv, const std::string &name) { std::string func = lua_tostring(lua, -1); lua_pushnil(lua); lua_setglobal(lua, (REDIS_LUA_REGISTER_FUNC_PREFIX + func).c_str()); + lua_pushnil(lua); + lua_setglobal(lua, (REDIS_LUA_REGISTER_FUNC_FLAGS_PREFIX + func).c_str()); auto _ = storage->Delete(rocksdb::WriteOptions(), cf, engine::kLuaFuncLibPrefix + func); lua_pop(lua, 1); } @@ -594,8 +618,6 @@ Status EvalGenericCommand(redis::Connection *conn, const std::string &body_or_sh /* Try to lookup the Lua function */ lua_getglobal(lua, funcname); - std::cout << "Try to lookup the Lua function\n"; - stackDump(lua); if (lua_isnil(lua, -1)) { lua_pop(lua, 1); /* remove the nil from the stack */ std::string body; @@ -608,31 +630,6 @@ Status EvalGenericCommand(redis::Connection *conn, const std::string &body_or_sh } else { body = body_or_sha; } - std::cout << "Get Body:\n" << body; - uint64_t script_flags = read_only ? ScriptFlags::kScriptNoWrites : 0; - if (auto pos = body.find('\n'); pos != std::string::npos) { - auto first_line = body.substr(0, pos); - std::cout << "\nGet First Line:" << first_line << '\n'; - - if (util::HasPrefix(first_line, "#!lua")) { - ShebangParser parser(first_line); - auto s = parser.Parse(); - if (!s.IsOK()) { - lua_pop(lua, 1); /* remove the error handler from the stack. */ - return s; - } - script_flags |= parser.GetFlags(); - } else { - // scripts without #! can run commands that access keys belonging to different cluster hash slots, - // but ones with #! inherit the default flags, so they cannot. - script_flags |= ScriptFlags::kScriptAllowCrossSlotKeys; - } - } - - ScriptRunCtx script_run_ctx; - script_run_ctx.flags = script_flags; - SaveOnRegistry(lua, REGISTRY_SCRIPT_RUN_CTX_NAME, &script_run_ctx); - // LoadScriptFlags(lua, script_flags); std::string sha = funcname + 2; auto s = CreateFunction(srv, body, &sha, lua, false); @@ -644,6 +641,12 @@ Status EvalGenericCommand(redis::Connection *conn, const std::string &body_or_sh lua_getglobal(lua, funcname); } + ScriptRunCtx current_script_run_ctx; + auto s = srv->GetScriptFlags(funcname + 2, current_script_run_ctx.flags); + if (!s.IsOK()) return s; + if (read_only) current_script_run_ctx.flags |= ScriptFlags::kScriptNoWrites; + SaveOnRegistry(lua, REGISTRY_SCRIPT_RUN_CTX_NAME, ¤t_script_run_ctx); + // For the Lua script, should be always run with RESP2 protocol, // unless users explicitly set the protocol version in the script via `redis.setresp`. // So we need to save the current protocol version and set it to RESP2, @@ -654,12 +657,7 @@ Status EvalGenericCommand(redis::Connection *conn, const std::string &body_or_sh * EVAL received. */ SetGlobalArray(lua, "KEYS", keys); SetGlobalArray(lua, "ARGV", argv); - // int errfunc_index = - std::cout << "Before EvalGenericCommand lua_pcall\n"; - stackDump(lua); if (lua_pcall(lua, 0, 1, -2)) { - std::cout << "After EvalGenericCommand lua_pcall\n"; - stackDump(lua); auto msg = fmt::format("running script (call to {}): {}", funcname, lua_tostring(lua, -1)); *output = redis::Error({Status::NotOK, msg}); lua_pop(lua, 2); @@ -674,6 +672,7 @@ Status EvalGenericCommand(redis::Connection *conn, const std::string &body_or_sh lua_setglobal(lua, "KEYS"); lua_pushnil(lua); lua_setglobal(lua, "ARGV"); + SaveOnRegistry(lua, REGISTRY_SCRIPT_RUN_CTX_NAME, nullptr); /* Call the Lua garbage collector from time to time to avoid a * full cycle performed by Lua, which adds too latency. @@ -714,9 +713,9 @@ Server *GetServer(lua_State *lua) { // TODO: we do not want to repeat same logic as Connection::ExecuteCommands, // so the function need to be refactored int RedisGenericCommand(lua_State *lua, int raise_error) { - ScriptRunCtx *script_run_ctx = GetFromRegistry(lua, REGISTRY_SCRIPT_RUN_CTX_NAME); - std::cout << "get script_flags = " << script_run_ctx->flags << '\n'; - stackDump(lua); + auto *script_run_ctx = GetFromRegistry(lua, REGISTRY_SCRIPT_RUN_CTX_NAME); + CHECK_NOTNULL(script_run_ctx); + int argc = lua_gettop(lua); if (argc == 0) { PushError(lua, "Please specify at least one argument for redis.call()"); @@ -777,8 +776,11 @@ int RedisGenericCommand(lua_State *lua, int raise_error) { } auto s = srv->cluster->CanExecByMySelf(attributes, args, conn, script_run_ctx); if (!s.IsOK()) { - std::cout << "CanExecByMySelf failed, s = " << s.Msg() << '\n'; - PushError(lua, redis::StatusToRedisErrorMsg(s).c_str()); + if (s.Is()) { + PushError(lua, "Script attempted to access a non local key in a cluster node script"); + } else { + PushError(lua, redis::StatusToRedisErrorMsg(s).c_str()); + } return raise_error ? RaiseError(lua) : 1; } } @@ -1356,8 +1358,6 @@ std::string ReplyToRedisReply(redis::Connection *conn, lua_State *lua) { [[noreturn]] int RaiseError(lua_State *lua) { lua_pushstring(lua, "err"); lua_gettable(lua, -2); - std::cout << "RaiseError\n"; - stackDump(lua); lua_error(lua); __builtin_unreachable(); } @@ -1488,8 +1488,73 @@ Status CreateFunction(Server *srv, const std::string &body, std::string *sha, lu } lua_setglobal(lua, funcname); + // Cache the flags of the current script + uint64_t script_flags = 0; + if (auto pos = body.find('\n'); pos != std::string::npos) { + auto first_line = body.substr(0, pos); + if (util::HasPrefix(first_line, "#!lua")) { + uint64_t shebang_flags = 0; + if (auto s = ExtractFlagsFromShebang(first_line, shebang_flags); !s.IsOK()) { + lua_pop(lua, 1); /* remove the error handler from the stack. */ + return s; + } + script_flags |= shebang_flags; + } else { + // scripts without #! can run commands that access keys belonging to different cluster hash slots, + // but ones with #! inherit the default flags, so they cannot. + script_flags |= ScriptFlags::kScriptAllowCrossSlotKeys; + } + } + srv->CacheScriptFlags(*sha, script_flags); + // would store lua function into propagate column family and propagate those scripts to slaves return need_to_store ? srv->ScriptSet(*sha, body) : Status::OK(); } +[[nodiscard]] Status ExtractLibNameFromShebang(const std::string &shebang, std::string &libname) { + static constexpr const char *shebang_prefix = "#!lua"; + static constexpr const char *shebang_libname_prefix = "name="; + + if (!util::HasPrefix(shebang, shebang_prefix)) { + return {Status::NotOK, "Expect shebang prefix \"#!lua\" at the beginning of the first line"}; + } + + if (auto pos = shebang.find(shebang_libname_prefix, strlen(shebang_prefix)); pos != std::string::npos) { + libname = shebang.substr(pos + strlen(shebang_libname_prefix)); + if (libname.empty() || + std::any_of(libname.begin(), libname.end(), [](char v) { return !std::isalnum(v) && v != '_'; })) { + return {Status::NotOK, "Expect a valid library name in the Shebang statement"}; + } + return Status::OK(); + } + + return {Status::NotOK, "Expect a library name in the Shebang statement"}; +} + +[[nodiscard]] Status ExtractFlagsFromShebang(const std::string &shebang, uint64_t &flags) { + static constexpr const char *shebang_prefix = "#!lua"; + static constexpr const char *shebang_flags_prefix = "flags="; + + if (auto pos = shebang.find(shebang_flags_prefix, strlen(shebang_prefix)); pos != std::string::npos) { + auto flags_content = shebang.substr(pos + strlen(shebang_flags_prefix)); + flags = 0; + for (const auto &flag : util::Split(flags_content, ",")) { + if (flag == "no-writes") { + flags |= kScriptNoWrites; + } else if (flag == "allow-oom") { + return {Status::NotSupported, "allow-oom is not supported yet"}; + } else if (flag == "allow-stale") { + return {Status::NotSupported, "allow-stale is not supported yet"}; + } else if (flag == "no-cluster") { + flags |= kScriptNoCluster; + } else if (flag == "allow-cross-slot-keys") { + flags |= kScriptAllowCrossSlotKeys; + } else { + return {Status::NotOK, "Unexpected flag in script shebang: " + flag}; + } + } + } + return Status::OK(); +} + } // namespace lua diff --git a/src/storage/scripting.h b/src/storage/scripting.h index 4feef5a4269..d789aba7ad2 100644 --- a/src/storage/scripting.h +++ b/src/storage/scripting.h @@ -29,10 +29,12 @@ inline constexpr const char REDIS_LUA_FUNC_SHA_PREFIX[] = "f_"; inline constexpr const char REDIS_LUA_REGISTER_FUNC_PREFIX[] = "__redis_registered_"; +inline constexpr const char REDIS_LUA_REGISTER_FUNC_FLAGS_PREFIX[] = "__redis_registered_flags_"; inline constexpr const char REDIS_LUA_SERVER_PTR[] = "__server_ptr"; inline constexpr const char REDIS_FUNCTION_LIBNAME[] = "REDIS_FUNCTION_LIBNAME"; inline constexpr const char REDIS_FUNCTION_NEEDSTORE[] = "REDIS_FUNCTION_NEEDSTORE"; -inline constexpr const char REDIS_FUNCTION_LIBRARIES[] = "REDIS_FUNCTION_LIBRARIES"; +inline constexpr const char REDIS_FUNCTION_LIBRARIES[] = "REDIS_FUNCTION_LxIBRARIES"; +inline constexpr const char REGISTRY_SCRIPT_RUN_CTX_NAME[] = "SCRIPT_RUN_CTX"; namespace lua { @@ -101,7 +103,13 @@ void SHA1Hex(char *digest, const char *script, size_t len); int RedisMathRandom(lua_State *l); int RedisMathRandomSeed(lua_State *l); -// TODO 注释,默认值,uint=>int +/// ScriptFlags turn on/off constraints or indicate properties in Eval scripts and functions +/// +/// Note: The default for Eval scripts are different than the default for functions(default is 0). +/// As soon as Redis sees the #! comment, it'll treat the script as if it declares flags, even if no flags are defined, +/// it still has a different set of defaults compared to a script without a #! line. +/// Another difference is that scripts without #! can run commands that access keys belonging to different cluster hash +/// slots, but ones with #! inherit the default flags, so they cannot. enum ScriptFlags : uint64_t { kScriptNoWrites = 1ULL << 0, // "no-writes" flag kScriptAllowOom = 1ULL << 1, // "allow-oom" flag @@ -110,98 +118,19 @@ enum ScriptFlags : uint64_t { kScriptAllowCrossSlotKeys = 1ULL << 4, // "allow-cross-slot-keys" flag }; -class ShebangParser { - public: - ShebangParser(const std::string &shebang) : shebang_(shebang) {} - - [[nodiscard]] Status Parse() { - std::cout << "Start Parse\n"; - static constexpr const char *shebang_prefix = "#!lua"; - static constexpr const char *shebang_libname_prefix = "name="; - static constexpr const char *shebang_flags_prefix = "flags="; - - if (!util::HasPrefix(shebang_, shebang_prefix)) { - return {Status::NotOK, "Expect shebang prefix \"#!lua\" at the beginning of the first line"}; - } - auto shebang_content = shebang_.substr(strlen(shebang_prefix)); - for (const auto &shebang_split : util::Split(shebang_content, " ")) { - std::cout << shebang_split << std::endl; - if (util::HasPrefix(shebang_split, shebang_libname_prefix)) { - if (!libname_.empty()) { - // TODO 已经有了 - return {Status::NotOK, "Expect a valid library name in the Shebang statement"}; - } - libname_ = shebang_split.substr(strlen(shebang_libname_prefix)); - if (libname_.empty() || - std::any_of(libname_.begin(), libname_.end(), [](char v) { return !std::isalnum(v) && v != '_'; })) { - return {Status::NotOK, "Expect a valid library name in the Shebang statement"}; - } - } else if (util::HasPrefix(shebang_split, shebang_flags_prefix)) { - auto flags = shebang_split.substr(strlen(shebang_flags_prefix)); - for (const auto &flag : util::Split(flags, ",")) { - if (flag == "no-writes") { - flags_ |= kScriptNoWrites; - } else if (flag == "allow-oom") { - return {Status::NotSupported, "allow-oom is not supported yet"}; - } else if (flag == "allow-stale") { - return {Status::NotSupported, "allow-stale is not supported yet"}; - } else if (flag == "no-cluster") { - flags_ |= kScriptNoCluster; - } else if (flag == "allow-cross-slot-keys") { - flags_ |= kScriptAllowCrossSlotKeys; - } else { - return {Status::NotOK, "Unexpected flag in script shebang: " + flag}; - } - } - } else { - return {Status::NotOK, "Expect a valid Shebang statement"}; - } - } - return Status::OK(); - } - - [[nodiscard]] uint64_t GetFlags() const { return flags_; } - [[nodiscard]] std::string GetLibName() const { return libname_; } - - private: - uint64_t flags_ = 0; - std::string libname_; - std::string shebang_; -}; +[[nodiscard]] Status ExtractLibNameFromShebang(const std::string &shebang, std::string &libname); +[[nodiscard]] Status ExtractFlagsFromShebang(const std::string &shebang, uint64_t &flags); -inline constexpr const char *REGISTRY_SCRIPT_RUN_CTX_NAME = "SCRIPT_RUN_CTX"; -// TODO 注释 +/// ScriptRunCtx is used to record context information during the running of Eval scripts and functions. struct ScriptRunCtx { + // ScriptFlags uint64_t flags = 0; + // current_slot tracks the slot currently accessed by the script + // and is used to detect whether there is cross-slot access + // between multiple commands in a script or function. int current_slot = -1; }; -static void stackDump(lua_State *L) { - int top = lua_gettop(L); - for (auto i = top; i >= 1; i--) { /* repeat for each level */ - int t = lua_type(L, i); - printf("%d: ", i); - switch (t) { - case LUA_TSTRING: /* strings */ - printf("`%s'", lua_tostring(L, i)); - break; - - case LUA_TBOOLEAN: /* booleans */ - printf(lua_toboolean(L, i) ? "true" : "false"); - break; - - case LUA_TNUMBER: /* numbers */ - printf("%g", lua_tonumber(L, i)); - break; - default: /* other values */ - printf("%s", lua_typename(L, t)); - break; - } - printf("\n"); /* put a separator */ - } - printf("\n"); /* end the listing */ -} - template void SaveOnRegistry(lua_State *lua, const char *name, T *ptr) { lua_pushstring(lua, name); @@ -219,17 +148,18 @@ T *GetFromRegistry(lua_State *lua, const char *name) { lua_gettable(lua, LUA_REGISTRYINDEX); if (lua_isnil(lua, -1)) { - lua_pop(lua, 1); /* pops the value */ + // pops the value + lua_pop(lua, 1); return nullptr; } - /* must be light user data */ + // must be light user data CHECK(lua_islightuserdata(lua, -1)); auto *ptr = static_cast(lua_touserdata(lua, -1)); CHECK_NOTNULL(ptr); - /* pops the value */ + // pops the value lua_pop(lua, 1); return ptr; diff --git a/tests/gocase/unit/scripting/function_test.go b/tests/gocase/unit/scripting/function_test.go index d0c12ec157a..47aa1e1d00f 100644 --- a/tests/gocase/unit/scripting/function_test.go +++ b/tests/gocase/unit/scripting/function_test.go @@ -22,6 +22,7 @@ package scripting import ( "context" _ "embed" + "fmt" "strings" "testing" @@ -116,10 +117,10 @@ var testFunctions = func(t *testing.T, enabledRESP3 string) { t.Run("FUNCTION LOAD errors", func(t *testing.T) { code := strings.Join(strings.Split(luaMylib1, "\n")[1:], "\n") - util.ErrorRegexp(t, rdb.Do(ctx, "FUNCTION", "LOAD", code).Err(), ".*Shebang statement.*") + util.ErrorRegexp(t, rdb.Do(ctx, "FUNCTION", "LOAD", code).Err(), "ERR Expect shebang prefix \"#!lua\" at the beginning of the first line") code2 := "#!lua\n" + code - util.ErrorRegexp(t, rdb.Do(ctx, "FUNCTION", "LOAD", code2).Err(), ".*Expect library name.*") + util.ErrorRegexp(t, rdb.Do(ctx, "FUNCTION", "LOAD", code2).Err(), "ERR Expect a library name in the Shebang statement") code2 = "#!lua name=$$$\n" + code util.ErrorRegexp(t, rdb.Do(ctx, "FUNCTION", "LOAD", code2).Err(), ".*valid library name.*") @@ -277,3 +278,215 @@ var testFunctions = func(t *testing.T, enabledRESP3 string) { }, decodeListLibResult(t, r)) }) } + +func TestFunctionScriptFlags(t *testing.T) { + srv := util.StartServer(t, map[string]string{}) + defer srv.Close() + + ctx := context.Background() + rdb := srv.NewClient() + defer func() { require.NoError(t, rdb.Close()) }() + + t.Run("no-writes", func(t *testing.T) { + r := rdb.Do(ctx, "FUNCTION", "LOAD", + `#!lua name=nowriteslib + redis.register_function('default_flag_func', function(keys, args) return redis.call("set", keys[1], args[1]) end) + redis.register_function('no_writes_func', function(keys, args) return redis.call("set", keys[1], args[1]) end, { 'no-writes' })`) + require.NoError(t, r.Err()) + + r = rdb.Do(ctx, "FCALL", "default_flag_func", 1, "k1", "v1") + require.NoError(t, r.Err()) + r = rdb.Do(ctx, "FCALL", "no_writes_func", 1, "k2", "v2") + util.ErrorRegexp(t, r.Err(), "ERR .* Write commands are not allowed from read-only scripts") + + r = rdb.Do(ctx, "FCALL_RO", "default_flag_func", 1, "k1", "v1") + util.ErrorRegexp(t, r.Err(), "ERR .* Write commands are not allowed from read-only scripts") + r = rdb.Do(ctx, "FCALL_RO", "no_writes_func", 1, "k2", "v2") + util.ErrorRegexp(t, r.Err(), "ERR .* Write commands are not allowed from read-only scripts") + }) + + srv0 := util.StartServer(t, map[string]string{"cluster-enabled": "yes"}) + rdb0 := srv0.NewClient() + defer func() { require.NoError(t, rdb0.Close()) }() + defer func() { srv0.Close() }() + id0 := "xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx00" + require.NoError(t, rdb0.Do(ctx, "clusterx", "SETNODEID", id0).Err()) + + srv1 := util.StartServer(t, map[string]string{"cluster-enabled": "yes"}) + srv1Alive := true + defer func() { + if srv1Alive { + srv1.Close() + } + }() + + rdb1 := srv1.NewClient() + defer func() { require.NoError(t, rdb1.Close()) }() + id1 := "xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx01" + require.NoError(t, rdb1.Do(ctx, "clusterx", "SETNODEID", id1).Err()) + + clusterNodes := fmt.Sprintf("%s %s %d master - 0-10000\n", id0, srv0.Host(), srv0.Port()) + clusterNodes += fmt.Sprintf("%s %s %d master - 10001-16383", id1, srv1.Host(), srv1.Port()) + require.NoError(t, rdb0.Do(ctx, "clusterx", "SETNODES", clusterNodes, "1").Err()) + require.NoError(t, rdb1.Do(ctx, "clusterx", "SETNODES", clusterNodes, "1").Err()) + + // Node0: bar-slot = 5061, test-slot = 6918 + // Node1: foo-slot = 12182 + // Different slots of different nodes are not affected by allow-cross-slot-keys, + // and different slots of the same node can be allowed + require.NoError(t, rdb0.Set(ctx, "bar", "bar_value", 0).Err()) + require.NoError(t, rdb0.Set(ctx, "test", "test_value", 0).Err()) + require.NoError(t, rdb1.Set(ctx, "foo", "foo_value", 0).Err()) + + t.Run("no-cluster", func(t *testing.T) { + r := rdb0.Do(ctx, "FUNCTION", "LOAD", + `#!lua name=noclusterlib + redis.register_function('default_flag_func', function(keys) return redis.call('get', keys[1]) end) + redis.register_function('no_cluster_func', function(keys) return redis.call('get', keys[1]) end, { 'no-cluster' })`) + require.NoError(t, r.Err()) + + require.NoError(t, rdb0.Do(ctx, "FCALL", "default_flag_func", 1, "bar").Err()) + + r = rdb0.Do(ctx, "FCALL", "no_cluster_func", 1, "bar") + util.ErrorRegexp(t, r.Err(), "ERR .* Can not run script on cluster, 'no-cluster' flag is set") + + // Only valid in cluster mode + require.NoError(t, rdb.Set(ctx, "bar", "rdb_bar_value", 0).Err()) + r = rdb.Do(ctx, "FUNCTION", "LOAD", + `#!lua name=noclusterlib + redis.register_function('no_cluster_func', function(keys) return redis.call('get', keys[1]) end, { 'no-cluster' })`) + require.NoError(t, r.Err()) + require.NoError(t, rdb.Do(ctx, "FCALL", "no_cluster_func", 1, "bar").Err()) + }) + + t.Run("allow-cross-slot-keys", func(t *testing.T) { + r := rdb0.Do(ctx, "FUNCTION", "LOAD", + `#!lua name=crossslotlib + redis.register_function('default_flag_func_1', + function() + redis.call('get', 'bar') + return redis.call('get', 'test') + end + ) + + redis.register_function('default_flag_func_2', + function() + redis.call('get', 'bar') + return redis.call('get', 'foo') + end + ) + + redis.register_function('default_flag_func_3', + function(keys) + redis.call('get', keys[1]) + return redis.call('get', keys[2]) + end + ) + + redis.register_function( + 'allow_cross_slot_keys_func_1', + function() + redis.call('get', 'bar') + return redis.call('get', 'test') + end, + { 'allow-cross-slot-keys' }) + + redis.register_function( + 'allow_cross_slot_keys_func_2', + function() + redis.call('get', 'bar') + return redis.call('get', 'foo') + end, + { 'allow-cross-slot-keys' }) + + redis.register_function( + 'allow_cross_slot_keys_func_3', + function(keys) + redis.call('get', key[1]) + return redis.call('get', key[2]) + end, + { 'allow-cross-slot-keys' }) + + `) + require.NoError(t, r.Err()) + + r = rdb0.Do(ctx, "FCALL", "default_flag_func_1", 0) + util.ErrorRegexp(t, r.Err(), "ERR .* Script attempted to access keys that do not hash to the same slot") + + r = rdb0.Do(ctx, "FCALL", "default_flag_func_2", 0) + util.ErrorRegexp(t, r.Err(), "ERR .* Script attempted to access a non local key in a cluster node script") + + r = rdb0.Do(ctx, "FCALL", "default_flag_func_3", 2, "bar", "test") + require.EqualError(t, r.Err(), "CROSSSLOT Attempted to access keys that don't hash to the same slot") + + r = rdb0.Do(ctx, "FCALL", "allow_cross_slot_keys_func_1", 0) + require.NoError(t, r.Err()) + + r = rdb0.Do(ctx, "FCALL", "allow_cross_slot_keys_func_2", 0) + util.ErrorRegexp(t, r.Err(), "ERR .* Script attempted to access a non local key in a cluster node script") + + // Pre-declared keys are not affected by allow-cross-slot-keys + r = rdb0.Do(ctx, "FCALL", "allow_cross_slot_keys_func_3", 2, "bar", "test") + require.EqualError(t, r.Err(), "CROSSSLOT Attempted to access keys that don't hash to the same slot") + }) + + t.Run("invalid-flags", func(t *testing.T) { + r := rdb.Do(ctx, "FUNCTION", "LOAD", + `#!lua name=invalidflagslib + redis.register_function('invalid_flag_func', function() end, { 'invalid-flag' })`) + require.NoError(t, r.Err()) + // Check whether the flag is valid only during FCALL, not during FUNCTION LOAD + r = rdb.Do(ctx, "FCALL", "invalid_flag_func", 0) + util.ErrorRegexp(t, r.Err(), "ERR Unexpected function flag:*") + }) + + t.Run("mixed-use", func(t *testing.T) { + r := rdb0.Do(ctx, "FUNCTION", "LOAD", + `#!lua name=mixeduselib + redis.register_function('no_write_cluster_func_1', function() redis.call('get', 'bar') end, { 'no-writes', 'no-cluster' }) + + redis.register_function('no_write_allow_cross_func_1', + function() redis.call('get', 'bar'); return redis.call('get', 'test'); end, + { 'no-writes', 'allow-cross-slot-keys' }) + + redis.register_function('no_write_allow_cross_func_2', + function() redis.call('set', 'bar'); return redis.call('set', 'test'); end, + { 'no-writes', 'allow-cross-slot-keys' }) + + redis.register_function('no_write_allow_cross_func_3', + function() redis.call('get', 'bar'); return redis.call('get', 'foo'); end, + { 'no-writes', 'allow-cross-slot-keys' }) + `) + require.NoError(t, r.Err()) + + r = rdb0.Do(ctx, "FCALL", "no_write_cluster_func_1", 0) + util.ErrorRegexp(t, r.Err(), "ERR .* Can not run script on cluster, 'no-cluster' flag is set") + + // no-cluster Only valid in cluster mode + r = rdb.Do(ctx, "FUNCTION", "LOAD", + `#!lua name=mixeduselib2 + redis.register_function('no_write_cluster_func_2', + function() return redis.call('set', 'bar', 'bar_value') end, + { 'no-writes', 'no-cluster' } + ) + + redis.register_function('no_write_cluster_func_3', + function() return redis.call('get', 'bar') end, + { 'no-writes', 'no-cluster' } + ) + `) + require.NoError(t, r.Err()) + + r = rdb.Do(ctx, "FCALL", "no_write_cluster_func_2", 0) + util.ErrorRegexp(t, r.Err(), "ERR .* Write commands are not allowed from read-only scripts") + + require.NoError(t, rdb.Set(ctx, "bar", "bar_value_rdb", 0).Err()) + require.NoError(t, rdb.Do(ctx, "FCALL", "no_write_cluster_func_3", 0).Err()) + + require.NoError(t, rdb0.Do(ctx, "FCALL", "no_write_allow_cross_func_1", 0).Err()) + r = rdb0.Do(ctx, "FCALL", "no_write_allow_cross_func_2", 0) + util.ErrorRegexp(t, r.Err(), "ERR .* Write commands are not allowed from read-only scripts") + r = rdb0.Do(ctx, "FCALL", "no_write_allow_cross_func_3", 0) + util.ErrorRegexp(t, r.Err(), "ERR .* Script attempted to access a non local key in a cluster node script") + }) +} diff --git a/tests/gocase/unit/scripting/scripting_test.go b/tests/gocase/unit/scripting/scripting_test.go index 497834b0753..c0d4517eb2b 100644 --- a/tests/gocase/unit/scripting/scripting_test.go +++ b/tests/gocase/unit/scripting/scripting_test.go @@ -480,6 +480,7 @@ math.randomseed(ARGV[1]); return tostring(math.random()) t.Run("EVALSHA_RO - cannot run write commands", func(t *testing.T) { require.NoError(t, rdb.Set(ctx, "foo", "bar", 0).Err()) + // sha1 of `redis.call('del', KEYS[1]);` r := rdb.Do(ctx, "EVALSHA_RO", "a1e63e1cd1bd1d5413851949332cfb9da4ee6dc0", "1", "foo") util.ErrorRegexp(t, r.Err(), "ERR .* Write commands are not allowed from read-only scripts") }) @@ -708,7 +709,7 @@ func TestEvalScriptFlags(t *testing.T) { `#!lua flags=allow-cross-slot-keys redis.call('set', 'foo','value_foo'); return redis.call('set', 'bar', 'value_bar');`, "0") - util.ErrorRegexp(t, r.Err(), "ERR .* MOVED *") + util.ErrorRegexp(t, r.Err(), "ERR .* Script attempted to access a non local key in a cluster node script") // There is a shebang prefix #!lua but crossslot is not allowed when flags are not set r = rdb0.Do(ctx, "EVAL", @@ -721,7 +722,7 @@ func TestEvalScriptFlags(t *testing.T) { `#!lua redis.call('get', 'foo'); return redis.call('get', 'bar');`, "0") - util.ErrorRegexp(t, r.Err(), "ERR .* MOVED *") + util.ErrorRegexp(t, r.Err(), "ERR .* Script attempted to access a non local key in a cluster node script") // Old style: CrossSlot is allowed when there is neither #!lua nor flags set r = rdb0.Do(ctx, "EVAL", @@ -732,9 +733,9 @@ func TestEvalScriptFlags(t *testing.T) { r = rdb0.Do(ctx, "EVAL", `redis.call('get', 'foo'); return redis.call('get', 'bar');`, "0") - util.ErrorRegexp(t, r.Err(), "ERR .* MOVED *") + util.ErrorRegexp(t, r.Err(), "ERR .* Script attempted to access a non local key in a cluster node script") - // Pre-declared keys are not affected by Arlo-Cross-Slot-Keyes + // Pre-declared keys are not affected by allow-cross-slot-keys r = rdb0.Do(ctx, "EVAL", `#!lua flags=allow-cross-slot-keys local key = redis.call('get', KEY[1]); @@ -783,6 +784,7 @@ func TestEvalScriptFlags(t *testing.T) { `#!lua flags=no-writes,allow-cross-slot-keys redis.call('get', 'bar'); return redis.call('get', 'foo');`, "0") - util.ErrorRegexp(t, r.Err(), "ERR .* MOVED *") + util.ErrorRegexp(t, r.Err(), "ERR .* Script attempted to access a non local key in a cluster node script") + }) } From 70296fa0a102d3d410f1765ca714a39b7e14e01f Mon Sep 17 00:00:00 2001 From: ZhouSiLe Date: Thu, 25 Jul 2024 15:10:41 +0800 Subject: [PATCH 04/13] fix: typo --- src/storage/scripting.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/storage/scripting.h b/src/storage/scripting.h index d789aba7ad2..8fc294c8e1f 100644 --- a/src/storage/scripting.h +++ b/src/storage/scripting.h @@ -33,7 +33,7 @@ inline constexpr const char REDIS_LUA_REGISTER_FUNC_FLAGS_PREFIX[] = "__redis_re inline constexpr const char REDIS_LUA_SERVER_PTR[] = "__server_ptr"; inline constexpr const char REDIS_FUNCTION_LIBNAME[] = "REDIS_FUNCTION_LIBNAME"; inline constexpr const char REDIS_FUNCTION_NEEDSTORE[] = "REDIS_FUNCTION_NEEDSTORE"; -inline constexpr const char REDIS_FUNCTION_LIBRARIES[] = "REDIS_FUNCTION_LxIBRARIES"; +inline constexpr const char REDIS_FUNCTION_LIBRARIES[] = "REDIS_FUNCTION_LIBRARIES"; inline constexpr const char REGISTRY_SCRIPT_RUN_CTX_NAME[] = "SCRIPT_RUN_CTX"; namespace lua { From 8db1c3a9288ee02d479a8261b3921dbc8e2a89f7 Mon Sep 17 00:00:00 2001 From: ZhouSiLe Date: Thu, 25 Jul 2024 16:11:16 +0800 Subject: [PATCH 05/13] fix: remove meaningless modification --- src/cluster/cluster.cc | 1 + src/storage/scripting.cc | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/src/cluster/cluster.cc b/src/cluster/cluster.cc index 5ce23954b20..b2b04d1add4 100644 --- a/src/cluster/cluster.cc +++ b/src/cluster/cluster.cc @@ -826,6 +826,7 @@ bool Cluster::IsWriteForbiddenSlot(int slot) const { Status Cluster::CanExecByMySelf(const redis::CommandAttributes *attributes, const std::vector &cmd_tokens, redis::Connection *conn, lua::ScriptRunCtx *script_run_ctx) { std::vector keys_indexes; + // No keys if (auto s = redis::CommandTable::GetKeysFromCommand(attributes, cmd_tokens, &keys_indexes); !s.IsOK()) return Status::OK(); diff --git a/src/storage/scripting.cc b/src/storage/scripting.cc index 1c1b4f74aec..7a0e81682fc 100644 --- a/src/storage/scripting.cc +++ b/src/storage/scripting.cc @@ -392,7 +392,6 @@ Status FunctionCall(redis::Connection *conn, const std::string &name, const std: uint64_t function_flags = read_only ? ScriptFlags::kScriptNoWrites : 0; lua_getglobal(lua, (REDIS_LUA_REGISTER_FUNC_FLAGS_PREFIX + name).c_str()); - // script, myfunc, err_func, table if (!lua_isnil(lua, -1)) { int n = static_cast(lua_objlen(lua, -1)); for (int i = 1; i <= n; ++i) { @@ -657,6 +656,7 @@ Status EvalGenericCommand(redis::Connection *conn, const std::string &body_or_sh * EVAL received. */ SetGlobalArray(lua, "KEYS", keys); SetGlobalArray(lua, "ARGV", argv); + if (lua_pcall(lua, 0, 1, -2)) { auto msg = fmt::format("running script (call to {}): {}", funcname, lua_tostring(lua, -1)); *output = redis::Error({Status::NotOK, msg}); From 2faae056c148938cf293d48df41ae7eb991a84ba Mon Sep 17 00:00:00 2001 From: ZhouSiLe Date: Thu, 25 Jul 2024 16:12:02 +0800 Subject: [PATCH 06/13] style: clang-format --- src/cluster/cluster.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/cluster/cluster.cc b/src/cluster/cluster.cc index b2b04d1add4..bcf13732525 100644 --- a/src/cluster/cluster.cc +++ b/src/cluster/cluster.cc @@ -826,7 +826,7 @@ bool Cluster::IsWriteForbiddenSlot(int slot) const { Status Cluster::CanExecByMySelf(const redis::CommandAttributes *attributes, const std::vector &cmd_tokens, redis::Connection *conn, lua::ScriptRunCtx *script_run_ctx) { std::vector keys_indexes; - + // No keys if (auto s = redis::CommandTable::GetKeysFromCommand(attributes, cmd_tokens, &keys_indexes); !s.IsOK()) return Status::OK(); From 236ec7a3257a5b03d6006e06b672f64f5337bdcf Mon Sep 17 00:00:00 2001 From: ZhouSiLe Date: Thu, 25 Jul 2024 16:12:46 +0800 Subject: [PATCH 07/13] style: clang-format --- src/storage/scripting.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/storage/scripting.cc b/src/storage/scripting.cc index 7a0e81682fc..26edda06962 100644 --- a/src/storage/scripting.cc +++ b/src/storage/scripting.cc @@ -656,7 +656,7 @@ Status EvalGenericCommand(redis::Connection *conn, const std::string &body_or_sh * EVAL received. */ SetGlobalArray(lua, "KEYS", keys); SetGlobalArray(lua, "ARGV", argv); - + if (lua_pcall(lua, 0, 1, -2)) { auto msg = fmt::format("running script (call to {}): {}", funcname, lua_tostring(lua, -1)); *output = redis::Error({Status::NotOK, msg}); From 2aa849692df0246fdfa1840817c61c5fd036c204 Mon Sep 17 00:00:00 2001 From: ZhouSiLe Date: Fri, 26 Jul 2024 18:36:21 +0800 Subject: [PATCH 08/13] refactor: better code --- src/cluster/cluster.cc | 2 +- src/server/server.cc | 11 - src/server/server.h | 4 - src/storage/scripting.cc | 218 ++++++++++++------ src/storage/scripting.h | 28 ++- tests/gocase/unit/scripting/function_test.go | 65 ++++-- tests/gocase/unit/scripting/scripting_test.go | 91 +++++++- 7 files changed, 303 insertions(+), 116 deletions(-) diff --git a/src/cluster/cluster.cc b/src/cluster/cluster.cc index bcf13732525..6a90d39df3b 100644 --- a/src/cluster/cluster.cc +++ b/src/cluster/cluster.cc @@ -855,7 +855,7 @@ Status Cluster::CanExecByMySelf(const redis::CommandAttributes *attributes, cons if (getNodeIDBySlot(script_run_ctx->current_slot) != getNodeIDBySlot(slot)) { return {Status::RedisMoved, fmt::format("{} {}:{}", slot, slots_nodes_[slot]->host, slots_nodes_[slot]->port)}; } - if (!(script_run_ctx->flags & lua::ScriptFlags::kScriptAllowCrossSlotKeys)) { + if (!(script_run_ctx->flags & lua::ScriptFlagType::kScriptAllowCrossSlotKeys)) { return {Status::RedisCrossSlot, "Script attempted to access keys that do not hash to the same slot"}; } } diff --git a/src/server/server.cc b/src/server/server.cc index 2b6bb581861..5e3b7f1363c 100644 --- a/src/server/server.cc +++ b/src/server/server.cc @@ -1682,16 +1682,6 @@ Status Server::ScriptSet(const std::string &sha, const std::string &body) const return storage->WriteToPropagateCF(func_name, body); } -void Server::CacheScriptFlags(const std::string &sha, uint64_t flags) { script_flags_cache_.try_emplace(sha, flags); } - -[[nodiscard]] Status Server::GetScriptFlags(const std::string &sha, uint64_t &flags) const { - if (script_flags_cache_.count(sha)) { - flags = script_flags_cache_.at(sha); - return Status::OK(); - } - return {Status::NotFound, "The flags cache of script sha does not exist, sha: " + sha}; -} - Status Server::FunctionGetCode(const std::string &lib, std::string *code) const { std::string func_name = engine::kLuaLibCodePrefix + lib; auto cf = storage->GetCFHandle(ColumnFamilyID::Propagate); @@ -1724,7 +1714,6 @@ Status Server::FunctionSetLib(const std::string &func, const std::string &lib) c void Server::ScriptReset() { auto lua = lua_.exchange(lua::CreateState(this)); - script_flags_cache_.clear(); lua::DestroyState(lua); } diff --git a/src/server/server.h b/src/server/server.h index 1b9b10a6464..c1793e81ce6 100644 --- a/src/server/server.h +++ b/src/server/server.h @@ -274,8 +274,6 @@ class Server { Status ScriptSet(const std::string &sha, const std::string &body) const; void ScriptReset(); Status ScriptFlush(); - void CacheScriptFlags(const std::string &sha, uint64_t flags); - [[nodiscard]] Status GetScriptFlags(const std::string &sha, uint64_t &flags) const; Status FunctionGetCode(const std::string &lib, std::string *code) const; Status FunctionGetLib(const std::string &func, std::string *lib) const; @@ -343,8 +341,6 @@ class Server { std::mutex last_random_key_cursor_mu_; std::atomic lua_; - // The cache of flag is cached when the script is created and cleared when the script is flushed. - std::unordered_map script_flags_cache_; redis::Connection *curr_connection_ = nullptr; diff --git a/src/storage/scripting.cc b/src/storage/scripting.cc index 26edda06962..d6e2be909af 100644 --- a/src/storage/scripting.cc +++ b/src/storage/scripting.cc @@ -221,8 +221,8 @@ int RedisLogCommand(lua_State *lua) { int RedisRegisterFunction(lua_State *lua) { int argc = lua_gettop(lua); - if (argc < 2) { - lua_pushstring(lua, "redis.register_function() requires at least two arguments."); + if (argc < 2 || argc > 3) { + lua_pushstring(lua, "wrong number of arguments to redis.register_function()."); return lua_error(lua); } @@ -239,6 +239,13 @@ int RedisRegisterFunction(lua_State *lua) { // set this function to global std::string name = lua_tostring(lua, 1); if (argc == 3) { + auto flags = ExtractFlagsFromRegisterFunction(lua); + if (!flags) { + lua_pushstring(lua, flags.Msg().c_str()); + return lua_error(lua); + } + // lua does not support unsigned integers, so flags are stored as strings + lua_pushstring(lua, std::to_string(flags.GetValue()).c_str()); lua_setglobal(lua, (REDIS_LUA_REGISTER_FUNC_FLAGS_PREFIX + name).c_str()); } lua_setglobal(lua, (REDIS_LUA_REGISTER_FUNC_PREFIX + name).c_str()); @@ -285,8 +292,7 @@ Status FunctionLoad(redis::Connection *conn, const std::string &script, bool nee return {Status::NotOK, "Expect a Shebang statement in the first line"}; } - std::string libname; - if (auto s = ExtractLibNameFromShebang(first_line, libname); !s.IsOK()) return s; + const auto libname = GET_OR_RET(ExtractLibNameFromShebang(first_line)); auto srv = conn->GetServer(); auto lua = read_only ? conn->Owner()->Lua() : srv->Lua(); @@ -390,29 +396,10 @@ Status FunctionCall(redis::Connection *conn, const std::string &name, const std: lua_getglobal(lua, (REDIS_LUA_REGISTER_FUNC_PREFIX + name).c_str()); } - uint64_t function_flags = read_only ? ScriptFlags::kScriptNoWrites : 0; + ScriptFlags function_flags = read_only ? ScriptFlagType::kScriptNoWrites : 0; lua_getglobal(lua, (REDIS_LUA_REGISTER_FUNC_FLAGS_PREFIX + name).c_str()); - if (!lua_isnil(lua, -1)) { - int n = static_cast(lua_objlen(lua, -1)); - for (int i = 1; i <= n; ++i) { - lua_pushnumber(lua, i); - lua_gettable(lua, -2); - std::string flag = lua_tostring(lua, -1); - if (flag == "no-writes") { - function_flags |= kScriptNoWrites; - } else if (flag == "allow-oom") { - return {Status::NotSupported, "allow-oom is not supported yet"}; - } else if (flag == "allow-stale") { - return {Status::NotSupported, "allow-stale is not supported yet"}; - } else if (flag == "no-cluster") { - function_flags |= kScriptNoCluster; - } else if (flag == "allow-cross-slot-keys") { - function_flags |= kScriptAllowCrossSlotKeys; - } else { - return {Status::NotOK, "Unexpected function flag: " + flag}; - } - lua_pop(lua, 1); - } + if (lua_isstring(lua, -1)) { + function_flags |= GET_OR_RET(ParseInt(lua_tostring(lua, -1))); } lua_pop(lua, 1); @@ -431,6 +418,22 @@ Status FunctionCall(redis::Connection *conn, const std::string &name, const std: lua_pop(lua, 2); } + SaveOnRegistry(lua, REGISTRY_SCRIPT_RUN_CTX_NAME, nullptr); + + /* Call the Lua garbage collector from time to time to avoid a + * full cycle performed by Lua, which adds too latency. + * + * The call is performed every LUA_GC_CYCLE_PERIOD executed commands + * (and for LUA_GC_CYCLE_PERIOD collection steps) because calling it + * for every command uses too much CPU. */ + constexpr int64_t LUA_GC_CYCLE_PERIOD = 50; + static int64_t gc_count = 0; + + gc_count++; + if (gc_count == LUA_GC_CYCLE_PERIOD) { + lua_gc(lua, LUA_GCSTEP, LUA_GC_CYCLE_PERIOD); + gc_count = 0; + } return Status::OK(); } @@ -641,9 +644,12 @@ Status EvalGenericCommand(redis::Connection *conn, const std::string &body_or_sh } ScriptRunCtx current_script_run_ctx; - auto s = srv->GetScriptFlags(funcname + 2, current_script_run_ctx.flags); - if (!s.IsOK()) return s; - if (read_only) current_script_run_ctx.flags |= ScriptFlags::kScriptNoWrites; + lua_getglobal(lua, fmt::format(REDIS_LUA_FUNC_SHA_FLAGS, funcname + 2).c_str()); + if (lua_isstring(lua, -1)) { + current_script_run_ctx.flags = GET_OR_RET(ParseInt(lua_tostring(lua, -1))); + } + lua_pop(lua, 1); + if (read_only) current_script_run_ctx.flags |= ScriptFlagType::kScriptNoWrites; SaveOnRegistry(lua, REGISTRY_SCRIPT_RUN_CTX_NAME, ¤t_script_run_ctx); // For the Lua script, should be always run with RESP2 protocol, @@ -748,7 +754,7 @@ int RedisGenericCommand(lua_State *lua, int raise_error) { auto attributes = cmd->GetAttributes(); auto cmd_flags = attributes->GenerateFlags(args); - if ((script_run_ctx->flags & ScriptFlags::kScriptNoWrites) && !(cmd_flags & redis::kCmdReadOnly)) { + if ((script_run_ctx->flags & ScriptFlagType::kScriptNoWrites) && !(cmd_flags & redis::kCmdReadOnly)) { PushError(lua, "Write commands are not allowed from read-only scripts"); return raise_error ? RaiseError(lua) : 1; } @@ -770,7 +776,7 @@ int RedisGenericCommand(lua_State *lua, int raise_error) { redis::Connection *conn = srv->GetCurrentConnection(); if (config->cluster_enabled) { - if (script_run_ctx->flags & ScriptFlags::kScriptNoCluster) { + if (script_run_ctx->flags & ScriptFlagType::kScriptNoCluster) { PushError(lua, "Can not run script on cluster, 'no-cluster' flag is set"); return raise_error ? RaiseError(lua) : 1; } @@ -1489,72 +1495,132 @@ Status CreateFunction(Server *srv, const std::string &body, std::string *sha, lu lua_setglobal(lua, funcname); // Cache the flags of the current script - uint64_t script_flags = 0; + ScriptFlags script_flags = 0; if (auto pos = body.find('\n'); pos != std::string::npos) { auto first_line = body.substr(0, pos); - if (util::HasPrefix(first_line, "#!lua")) { - uint64_t shebang_flags = 0; - if (auto s = ExtractFlagsFromShebang(first_line, shebang_flags); !s.IsOK()) { - lua_pop(lua, 1); /* remove the error handler from the stack. */ - return s; - } - script_flags |= shebang_flags; - } else { - // scripts without #! can run commands that access keys belonging to different cluster hash slots, - // but ones with #! inherit the default flags, so they cannot. - script_flags |= ScriptFlags::kScriptAllowCrossSlotKeys; - } + script_flags = GET_OR_RET(ExtractFlagsFromShebang(first_line)); + } else { + // scripts without #! can run commands that access keys belonging to different cluster hash slots + script_flags = kScriptAllowCrossSlotKeys; } - srv->CacheScriptFlags(*sha, script_flags); + lua_pushstring(lua, std::to_string(script_flags).c_str()); + lua_setglobal(lua, fmt::format(REDIS_LUA_FUNC_SHA_FLAGS, *sha).c_str()); // would store lua function into propagate column family and propagate those scripts to slaves return need_to_store ? srv->ScriptSet(*sha, body) : Status::OK(); } -[[nodiscard]] Status ExtractLibNameFromShebang(const std::string &shebang, std::string &libname) { - static constexpr const char *shebang_prefix = "#!lua"; - static constexpr const char *shebang_libname_prefix = "name="; +[[nodiscard]] StatusOr ExtractLibNameFromShebang(std::string_view shebang) { + static constexpr std::string_view lua_shebang_prefix = "#!lua"; + static constexpr std::string_view shebang_libname_prefix = "name="; + + if (shebang.substr(0, 2) != "#!") { + return {Status::NotOK, "Missing library meta"}; + } - if (!util::HasPrefix(shebang, shebang_prefix)) { - return {Status::NotOK, "Expect shebang prefix \"#!lua\" at the beginning of the first line"}; + auto shebang_splits = util::Split(shebang, " "); + if (shebang_splits.empty() || shebang_splits[0] != lua_shebang_prefix) { + return {Status::NotOK, "Unexpected engine in script shebang: " + shebang_splits[0]}; } - if (auto pos = shebang.find(shebang_libname_prefix, strlen(shebang_prefix)); pos != std::string::npos) { - libname = shebang.substr(pos + strlen(shebang_libname_prefix)); + std::string libname; + bool found_libname = false; + for (size_t i = 1; i < shebang_splits.size(); i++) { + std::string_view shebang_split_sv = shebang_splits[i]; + if (shebang_split_sv.substr(0, shebang_libname_prefix.size()) != shebang_libname_prefix) { + return {Status::NotOK, "Unknown lua shebang option: " + shebang_splits[i]}; + } + if (found_libname) { + return {Status::NotOK, "Redundant library name in script shebang"}; + } + + libname = shebang_split_sv.substr(shebang_libname_prefix.size()); if (libname.empty() || std::any_of(libname.begin(), libname.end(), [](char v) { return !std::isalnum(v) && v != '_'; })) { - return {Status::NotOK, "Expect a valid library name in the Shebang statement"}; + return { + Status::NotOK, + "Library names can only contain letters, numbers, or underscores(_) and must be at least one character long"}; } - return Status::OK(); + found_libname = true; } - return {Status::NotOK, "Expect a library name in the Shebang statement"}; + if (found_libname) return libname; + return {Status::NotOK, "Library name was not given"}; } -[[nodiscard]] Status ExtractFlagsFromShebang(const std::string &shebang, uint64_t &flags) { - static constexpr const char *shebang_prefix = "#!lua"; - static constexpr const char *shebang_flags_prefix = "flags="; - - if (auto pos = shebang.find(shebang_flags_prefix, strlen(shebang_prefix)); pos != std::string::npos) { - auto flags_content = shebang.substr(pos + strlen(shebang_flags_prefix)); - flags = 0; - for (const auto &flag : util::Split(flags_content, ",")) { - if (flag == "no-writes") { - flags |= kScriptNoWrites; - } else if (flag == "allow-oom") { - return {Status::NotSupported, "allow-oom is not supported yet"}; - } else if (flag == "allow-stale") { - return {Status::NotSupported, "allow-stale is not supported yet"}; - } else if (flag == "no-cluster") { - flags |= kScriptNoCluster; - } else if (flag == "allow-cross-slot-keys") { - flags |= kScriptAllowCrossSlotKeys; - } else { - return {Status::NotOK, "Unexpected flag in script shebang: " + flag}; +[[nodiscard]] StatusOr GetFlagsFromStrings(const std::vector &flags_content) { + ScriptFlags flags = 0; + for (const auto &flag : flags_content) { + if (flag == "no-writes") { + flags |= kScriptNoWrites; + } else if (flag == "allow-oom") { + return {Status::NotSupported, "allow-oom is not supported yet"}; + } else if (flag == "allow-stale") { + return {Status::NotSupported, "allow-stale is not supported yet"}; + } else if (flag == "no-cluster") { + flags |= kScriptNoCluster; + } else if (flag == "allow-cross-slot-keys") { + flags |= kScriptAllowCrossSlotKeys; + } else { + return {Status::NotOK, "Unknown flag given: " + flag}; + } + } + return flags; +} + +[[nodiscard]] StatusOr ExtractFlagsFromShebang(std::string_view shebang) { + static constexpr std::string_view lua_shebang_prefix = "#!lua"; + static constexpr std::string_view shebang_flags_prefix = "flags="; + + ScriptFlags result_flags = 0; + if (shebang.substr(0, 2) == "#!") { + auto shebang_splits = util::Split(shebang, " "); + if (shebang_splits.empty() || shebang_splits[0] != lua_shebang_prefix) { + return {Status::NotOK, "Unexpected engine in script shebang: " + shebang_splits[0]}; + } + bool found_flags = false; + for (size_t i = 1; i < shebang_splits.size(); i++) { + std::string_view shebang_split_sv = shebang_splits[i]; + if (shebang_split_sv.substr(0, shebang_flags_prefix.size()) != shebang_flags_prefix) { + return {Status::NotOK, "Unknown lua shebang option: " + shebang_splits[i]}; } + if (found_flags) { + return {Status::NotOK, "Redundant flags in script shebang"}; + } + auto flags_content = util::Split(shebang_split_sv.substr(shebang_flags_prefix.size()), ","); + result_flags |= GET_OR_RET(GetFlagsFromStrings(flags_content)); + found_flags = true; } + } else { + // scripts without #! can run commands that access keys belonging to different cluster hash slots, + // but ones with #! inherit the default flags, so they cannot. + result_flags = kScriptAllowCrossSlotKeys; } - return Status::OK(); + + return result_flags; +} + +[[nodiscard]] StatusOr ExtractFlagsFromRegisterFunction(lua_State *lua) { + if (!lua_istable(lua, -1)) { + return {Status::NotOK, "Expects a valid flags argument to register_function, e.g. flags={ 'no-writes' }"}; + } + auto flag_count = static_cast(lua_objlen(lua, -1)); + std::vector flags_content; + flags_content.reserve(flag_count); + for (int i = 1; i <= flag_count; ++i) { + lua_pushnumber(lua, i); + lua_gettable(lua, -2); + if (!lua_isstring(lua, -1)) { + return {Status::NotOK, "Expects a valid flags argument to register_function, e.g. flags={ 'no-writes' }"}; + } + flags_content.emplace_back(lua_tostring(lua, -1)); + // pop up the current flag + lua_pop(lua, 1); + } + // pop up the corresponding table of the flags parameter + lua_pop(lua, 1); + + return GetFlagsFromStrings(flags_content); } } // namespace lua diff --git a/src/storage/scripting.h b/src/storage/scripting.h index 8fc294c8e1f..0cfb8b5f634 100644 --- a/src/storage/scripting.h +++ b/src/storage/scripting.h @@ -28,6 +28,7 @@ #include "status.h" inline constexpr const char REDIS_LUA_FUNC_SHA_PREFIX[] = "f_"; +inline constexpr const char REDIS_LUA_FUNC_SHA_FLAGS[] = "f_{}_flags_"; inline constexpr const char REDIS_LUA_REGISTER_FUNC_PREFIX[] = "__redis_registered_"; inline constexpr const char REDIS_LUA_REGISTER_FUNC_FLAGS_PREFIX[] = "__redis_registered_flags_"; inline constexpr const char REDIS_LUA_SERVER_PTR[] = "__server_ptr"; @@ -103,14 +104,14 @@ void SHA1Hex(char *digest, const char *script, size_t len); int RedisMathRandom(lua_State *l); int RedisMathRandomSeed(lua_State *l); -/// ScriptFlags turn on/off constraints or indicate properties in Eval scripts and functions +/// ScriptFlagType turn on/off constraints or indicate properties in Eval scripts and functions /// /// Note: The default for Eval scripts are different than the default for functions(default is 0). /// As soon as Redis sees the #! comment, it'll treat the script as if it declares flags, even if no flags are defined, /// it still has a different set of defaults compared to a script without a #! line. /// Another difference is that scripts without #! can run commands that access keys belonging to different cluster hash /// slots, but ones with #! inherit the default flags, so they cannot. -enum ScriptFlags : uint64_t { +enum ScriptFlagType : uint64_t { kScriptNoWrites = 1ULL << 0, // "no-writes" flag kScriptAllowOom = 1ULL << 1, // "allow-oom" flag kScriptAllowStale = 1ULL << 2, // "allow-stale" flag @@ -118,8 +119,24 @@ enum ScriptFlags : uint64_t { kScriptAllowCrossSlotKeys = 1ULL << 4, // "allow-cross-slot-keys" flag }; -[[nodiscard]] Status ExtractLibNameFromShebang(const std::string &shebang, std::string &libname); -[[nodiscard]] Status ExtractFlagsFromShebang(const std::string &shebang, uint64_t &flags); +/// ScriptFlags is composed of one or more ScriptFlagTypes combined by an OR operation +/// For example, ScriptFlags flags = kScriptNoWrites | kScriptNoCluster +using ScriptFlags = uint64_t; + +[[nodiscard]] StatusOr ExtractLibNameFromShebang(std::string_view shebang); +[[nodiscard]] StatusOr ExtractFlagsFromShebang(std::string_view shebang); + +/// GetFlagsFromStrings gets flags from flags_content and composites them together. +/// Each element in flags_content should correspond to a string form of ScriptFlagType +[[nodiscard]] StatusOr GetFlagsFromStrings(const std::vector &flags_content); + +/// ExtractFlagsFromRegisterFunction extracts the flags from the redis.register_function +/// +/// Note: When using it, you should make sure that +/// the top of the stack of lua is the flags parameter of redis.register_function. +/// The flags parameter in Lua is a table that stores strings. +/// After use, the original flags table on the top of the stack will be popped. +[[nodiscard]] StatusOr ExtractFlagsFromRegisterFunction(lua_State *lua); /// ScriptRunCtx is used to record context information during the running of Eval scripts and functions. struct ScriptRunCtx { @@ -131,6 +148,9 @@ struct ScriptRunCtx { int current_slot = -1; }; +/// SaveOnRegistry saves user-defined data to lua REGISTRY +/// +/// Note: Since lua_pushlightuserdata, you need to manage the life cycle of the data stored in the Registry yourself. template void SaveOnRegistry(lua_State *lua, const char *name, T *ptr) { lua_pushstring(lua, name); diff --git a/tests/gocase/unit/scripting/function_test.go b/tests/gocase/unit/scripting/function_test.go index 47aa1e1d00f..8fa16037f21 100644 --- a/tests/gocase/unit/scripting/function_test.go +++ b/tests/gocase/unit/scripting/function_test.go @@ -117,13 +117,13 @@ var testFunctions = func(t *testing.T, enabledRESP3 string) { t.Run("FUNCTION LOAD errors", func(t *testing.T) { code := strings.Join(strings.Split(luaMylib1, "\n")[1:], "\n") - util.ErrorRegexp(t, rdb.Do(ctx, "FUNCTION", "LOAD", code).Err(), "ERR Expect shebang prefix \"#!lua\" at the beginning of the first line") + require.Error(t, rdb.Do(ctx, "FUNCTION", "LOAD", code).Err(), "ERR Missing library meta") code2 := "#!lua\n" + code - util.ErrorRegexp(t, rdb.Do(ctx, "FUNCTION", "LOAD", code2).Err(), "ERR Expect a library name in the Shebang statement") + require.Error(t, rdb.Do(ctx, "FUNCTION", "LOAD", code2).Err(), "ERR Library name was not given") code2 = "#!lua name=$$$\n" + code - util.ErrorRegexp(t, rdb.Do(ctx, "FUNCTION", "LOAD", code2).Err(), ".*valid library name.*") + require.Error(t, rdb.Do(ctx, "FUNCTION", "LOAD", code2).Err(), "ERR Library names can only contain letters, numbers, or underscores(_) and must be at least one character long") }) t.Run("FUNCTION LOAD and FCALL mylib1", func(t *testing.T) { @@ -287,6 +287,55 @@ func TestFunctionScriptFlags(t *testing.T) { rdb := srv.NewClient() defer func() { require.NoError(t, rdb.Close()) }() + t.Run("Function extract-libname-error", func(t *testing.T) { + r := rdb.Do(ctx, "FUNCTION", "LOAD", + `#!lua name=mylibname flags=no-writes + redis.register_function('extract_libname_error_func', function(keys, args) end)`) + util.ErrorRegexp(t, r.Err(), "ERR Unknown lua shebang option:*") + + r = rdb.Do(ctx, "FUNCTION", "LOAD", + `#!lua flags=no-writes name=mylibname + redis.register_function('extract_libname_error_func', function(keys, args) end)`) + util.ErrorRegexp(t, r.Err(), "ERR Unknown lua shebang option:*") + + r = rdb.Do(ctx, "FUNCTION", "LOAD", + `#!lua name=mylibname name=mylibname2 + redis.register_function('extract_libname_error_func', function(keys, args) end)`) + util.ErrorRegexp(t, r.Err(), "Redundant library name in script shebang") + + r = rdb.Do(ctx, "FUNCTION", "LOAD", + `#!errorenine name=mylibname + redis.register_function('extract_libname_error_func', function(keys, args) end)`) + util.ErrorRegexp(t, r.Err(), "ERR Unexpected engine in script shebang:*") + }) + + t.Run("Function extract-flags-error", func(t *testing.T) { + r := rdb.Do(ctx, "FUNCTION", "LOAD", + `#!lua name=myflags + redis.register_function('extract_flags_error_func', function(keys, args) end, { 'invalid-flag' })`) + require.Error(t, r.Err(), "ERR Error while running new function lib: Unknown flag given: invalid-flag") + + r = rdb.Do(ctx, "FUNCTION", "LOAD", + `#!lua name=myflags + redis.register_function('extract_flags_error_func', function(keys, args) end, { 'no-writes', 'invalid-flag' })`) + require.Error(t, r.Err(), "ERR Error while running new function lib: Unknown flag given: invalid-flag") + + r = rdb.Do(ctx, "FUNCTION", "LOAD", + `#!lua name=myflags + redis.register_function('extract_flags_error_func', function(keys, args) end, { {} }`) + require.Error(t, r.Err(), "ERR Expects a valid flags argument to register_function, e.g. flags={ 'no-writes' })") + + r = rdb.Do(ctx, "FUNCTION", "LOAD", + `#!lua name=myflags + redis.register_function('extract_flags_error_func', function(keys, args) end, { 123 }`) + require.Error(t, r.Err(), "ERR Expects a valid flags argument to register_function, e.g. flags={ 'no-writes' })") + + r = rdb.Do(ctx, "FUNCTION", "LOAD", + `#!lua name=myflags + redis.register_function('extract_flags_error_func', function(keys, args) end, 'no-writes'`) + require.Error(t, r.Err(), "ERR Expects a valid flags argument to register_function, e.g. flags={ 'no-writes' })") + }) + t.Run("no-writes", func(t *testing.T) { r := rdb.Do(ctx, "FUNCTION", "LOAD", `#!lua name=nowriteslib @@ -430,16 +479,6 @@ func TestFunctionScriptFlags(t *testing.T) { require.EqualError(t, r.Err(), "CROSSSLOT Attempted to access keys that don't hash to the same slot") }) - t.Run("invalid-flags", func(t *testing.T) { - r := rdb.Do(ctx, "FUNCTION", "LOAD", - `#!lua name=invalidflagslib - redis.register_function('invalid_flag_func', function() end, { 'invalid-flag' })`) - require.NoError(t, r.Err()) - // Check whether the flag is valid only during FCALL, not during FUNCTION LOAD - r = rdb.Do(ctx, "FCALL", "invalid_flag_func", 0) - util.ErrorRegexp(t, r.Err(), "ERR Unexpected function flag:*") - }) - t.Run("mixed-use", func(t *testing.T) { r := rdb0.Do(ctx, "FUNCTION", "LOAD", `#!lua name=mixeduselib diff --git a/tests/gocase/unit/scripting/scripting_test.go b/tests/gocase/unit/scripting/scripting_test.go index c0d4517eb2b..bb3a4fae432 100644 --- a/tests/gocase/unit/scripting/scripting_test.go +++ b/tests/gocase/unit/scripting/scripting_test.go @@ -617,6 +617,90 @@ func TestEvalScriptFlags(t *testing.T) { rdb := srv.NewClient() defer func() { require.NoError(t, rdb.Close()) }() + t.Run("Eval extract-flags-error", func(t *testing.T) { + r := rdb.Do(ctx, "EVAL", + `#!lua name=mylib + return 'extract-flags'`, "0") + util.ErrorRegexp(t, r.Err(), "ERR Unknown lua shebang option:*") + + r = rdb.Do(ctx, "EVAL", + `#!lua flags=no-writes name=mylib + return 'extract-flags'`, "0") + util.ErrorRegexp(t, r.Err(), "ERR Unknown lua shebang option:*") + + r = rdb.Do(ctx, "EVAL", + `#!lua erroroption=no-writes + return 'extract-flags'`, "0") + util.ErrorRegexp(t, r.Err(), "ERR Unknown lua shebang option:*") + + r = rdb.Do(ctx, "EVAL", + `#!lua flags=invalid-flag + return 'extract-flags'`, "0") + util.ErrorRegexp(t, r.Err(), "ERR Unknown flag given:*") + + r = rdb.Do(ctx, "EVAL", + `#!lua flags=no-writes,invalid-flag + return 'extract-flags'`, "0") + util.ErrorRegexp(t, r.Err(), "ERR Unknown flag given:*") + + r = rdb.Do(ctx, "EVAL", + `#!lua flags=no-writes no-cluster + return 'extract-flags'`, "0") + util.ErrorRegexp(t, r.Err(), "ERR Unknown lua shebang option:*") + + r = rdb.Do(ctx, "EVAL", + `#!lua flags=no-writes flags=no-cluster + return 'extract-flags'`, "0") + util.ErrorRegexp(t, r.Err(), "ERR Redundant flags in script shebang") + + r = rdb.Do(ctx, "EVAL", + `#!errorengine flags=no-writes + return 'extract-flags'`, "0") + util.ErrorRegexp(t, r.Err(), "ERR Unexpected engine in script shebang:*") + }) + + t.Run("SCRIPT LOAD extract-flags-error", func(t *testing.T) { + r := rdb.Do(ctx, "SCRIPT", "LOAD", + `#!lua name=mylib + return 'extract-flags'`) + util.ErrorRegexp(t, r.Err(), "ERR Unknown lua shebang option:*") + + r = rdb.Do(ctx, "SCRIPT", "LOAD", + `#!lua flags=no-writes name=mylib + return 'extract-flags'`) + util.ErrorRegexp(t, r.Err(), "ERR Unknown lua shebang option:*") + + r = rdb.Do(ctx, "SCRIPT", "LOAD", + `#!lua erroroption=no-writes + return 'extract-flags'`) + util.ErrorRegexp(t, r.Err(), "ERR Unknown lua shebang option:*") + + r = rdb.Do(ctx, "SCRIPT", "LOAD", + `#!lua flags=invalid-flag + return 'extract-flags'`) + util.ErrorRegexp(t, r.Err(), "ERR Unknown flag given::*") + + r = rdb.Do(ctx, "SCRIPT", "LOAD", + `#!lua flags=no-writes,invalid-flag + return 'extract-flags'`) + util.ErrorRegexp(t, r.Err(), "ERR Unknown flag given::*") + + r = rdb.Do(ctx, "SCRIPT", "LOAD", + `#!lua flags=no-writes no-cluster + return 'extract-flags'`) + util.ErrorRegexp(t, r.Err(), "ERR Unknown lua shebang option:*") + + r = rdb.Do(ctx, "SCRIPT", "LOAD", + `#!lua flags=no-writes flags=no-cluster + return 'extract-flags'`) + util.ErrorRegexp(t, r.Err(), "ERR Redundant flags in script shebang") + + r = rdb.Do(ctx, "SCRIPT", "LOAD", + `#!errorengine flags=no-writes + return 'extract-flags'`) + util.ErrorRegexp(t, r.Err(), "ERR Unexpected engine in script shebang:*") + }) + t.Run("no-writes", func(t *testing.T) { r := rdb.Do(ctx, "EVAL", `#!lua flags=no-writes @@ -743,13 +827,6 @@ func TestEvalScriptFlags(t *testing.T) { require.EqualError(t, r.Err(), "CROSSSLOT Attempted to access keys that don't hash to the same slot") }) - t.Run("invalid-flags", func(t *testing.T) { - r := rdb0.Do(ctx, "EVAL", - `#!lua flags=invalid-flag - return redis.call('set', 'k','v');`, "0") - util.ErrorRegexp(t, r.Err(), "ERR Unexpected flag in script shebang:*") - }) - t.Run("mixed use", func(t *testing.T) { r := rdb0.Do(ctx, "EVAL", `#!lua flags=no-writes,no-cluster From b6e2d33c8a5855e9762205c056f61add3434205b Mon Sep 17 00:00:00 2001 From: ZhouSiLe Date: Fri, 26 Jul 2024 18:59:38 +0800 Subject: [PATCH 09/13] test: more flags and libname test cases --- tests/gocase/unit/scripting/function_test.go | 15 ++++++++++++++ tests/gocase/unit/scripting/scripting_test.go | 20 +++++++++++++++++++ 2 files changed, 35 insertions(+) diff --git a/tests/gocase/unit/scripting/function_test.go b/tests/gocase/unit/scripting/function_test.go index 8fa16037f21..8d71158498a 100644 --- a/tests/gocase/unit/scripting/function_test.go +++ b/tests/gocase/unit/scripting/function_test.go @@ -307,6 +307,21 @@ func TestFunctionScriptFlags(t *testing.T) { `#!errorenine name=mylibname redis.register_function('extract_libname_error_func', function(keys, args) end)`) util.ErrorRegexp(t, r.Err(), "ERR Unexpected engine in script shebang:*") + + r = rdb.Do(ctx, "FUNCTION", "LOAD", + `#!luaname=mylibname + redis.register_function('extract_libname_error_func', function(keys, args) end)`) + util.ErrorRegexp(t, r.Err(), "ERR Unexpected engine in script shebang:*") + + r = rdb.Do(ctx, "FUNCTION", "LOAD", + `#!lua xxxname=mylibname + redis.register_function('extract_libname_error_func', function(keys, args) end)`) + util.ErrorRegexp(t, r.Err(), "ERR Unknown lua shebang option:*") + + r = rdb.Do(ctx, "FUNCTION", "LOAD", + `#!lua name=mylibname key=value + redis.register_function('extract_libname_error_func', function(keys, args) end)`) + util.ErrorRegexp(t, r.Err(), "ERR Unknown lua shebang option:*") }) t.Run("Function extract-flags-error", func(t *testing.T) { diff --git a/tests/gocase/unit/scripting/scripting_test.go b/tests/gocase/unit/scripting/scripting_test.go index bb3a4fae432..9e4f275239a 100644 --- a/tests/gocase/unit/scripting/scripting_test.go +++ b/tests/gocase/unit/scripting/scripting_test.go @@ -657,6 +657,16 @@ func TestEvalScriptFlags(t *testing.T) { `#!errorengine flags=no-writes return 'extract-flags'`, "0") util.ErrorRegexp(t, r.Err(), "ERR Unexpected engine in script shebang:*") + + r = rdb.Do(ctx, "EVAL", + `#!luaflags=no-writes + return 'extract-flags'`, "0") + util.ErrorRegexp(t, r.Err(), "ERR Unexpected engine in script shebang:*") + + r = rdb.Do(ctx, "EVAL", + `#!lua xxflags=no-writes + return 'extract-flags'`, "0") + util.ErrorRegexp(t, r.Err(), "ERR Unknown lua shebang option:*") }) t.Run("SCRIPT LOAD extract-flags-error", func(t *testing.T) { @@ -699,6 +709,16 @@ func TestEvalScriptFlags(t *testing.T) { `#!errorengine flags=no-writes return 'extract-flags'`) util.ErrorRegexp(t, r.Err(), "ERR Unexpected engine in script shebang:*") + + r = rdb.Do(ctx, "SCRIPT", "LOAD", + `#!luaflags=no-writes + return 'extract-flags'`) + util.ErrorRegexp(t, r.Err(), "ERR Unexpected engine in script shebang:*") + + r = rdb.Do(ctx, "SCRIPT", "LOAD", + `#!lua xxflags=no-writes + return 'extract-flags'`) + util.ErrorRegexp(t, r.Err(), "ERR Unknown lua shebang option:*") }) t.Run("no-writes", func(t *testing.T) { From 972ef5bd25460a7b4f53d9f6c14737b025853ba6 Mon Sep 17 00:00:00 2001 From: ZhouSiLe Date: Sat, 27 Jul 2024 00:26:01 +0800 Subject: [PATCH 10/13] refactor: store flag as string => integer --- src/storage/scripting.cc | 27 ++++++++++++++++----------- 1 file changed, 16 insertions(+), 11 deletions(-) diff --git a/src/storage/scripting.cc b/src/storage/scripting.cc index d6e2be909af..632a6e56f95 100644 --- a/src/storage/scripting.cc +++ b/src/storage/scripting.cc @@ -244,8 +244,7 @@ int RedisRegisterFunction(lua_State *lua) { lua_pushstring(lua, flags.Msg().c_str()); return lua_error(lua); } - // lua does not support unsigned integers, so flags are stored as strings - lua_pushstring(lua, std::to_string(flags.GetValue()).c_str()); + lua_pushinteger(lua, static_cast(flags.GetValue())); lua_setglobal(lua, (REDIS_LUA_REGISTER_FUNC_FLAGS_PREFIX + name).c_str()); } lua_setglobal(lua, (REDIS_LUA_REGISTER_FUNC_PREFIX + name).c_str()); @@ -396,15 +395,17 @@ Status FunctionCall(redis::Connection *conn, const std::string &name, const std: lua_getglobal(lua, (REDIS_LUA_REGISTER_FUNC_PREFIX + name).c_str()); } - ScriptFlags function_flags = read_only ? ScriptFlagType::kScriptNoWrites : 0; + ScriptRunCtx script_run_ctx; + script_run_ctx.flags = read_only ? ScriptFlagType::kScriptNoWrites : 0; lua_getglobal(lua, (REDIS_LUA_REGISTER_FUNC_FLAGS_PREFIX + name).c_str()); - if (lua_isstring(lua, -1)) { - function_flags |= GET_OR_RET(ParseInt(lua_tostring(lua, -1))); + if (!lua_isnil(lua, -1)) { + int isnum = false; + auto function_flags = lua_tointegerx(lua, -1, &isnum); + if (!isnum) return {Status::NotOK, "Invalid function flags"}; + script_run_ctx.flags |= function_flags; } lua_pop(lua, 1); - ScriptRunCtx script_run_ctx; - script_run_ctx.flags = function_flags; SaveOnRegistry(lua, REGISTRY_SCRIPT_RUN_CTX_NAME, &script_run_ctx); PushArray(lua, keys); @@ -644,12 +645,16 @@ Status EvalGenericCommand(redis::Connection *conn, const std::string &body_or_sh } ScriptRunCtx current_script_run_ctx; + current_script_run_ctx.flags = read_only ? ScriptFlagType::kScriptNoWrites : 0; lua_getglobal(lua, fmt::format(REDIS_LUA_FUNC_SHA_FLAGS, funcname + 2).c_str()); - if (lua_isstring(lua, -1)) { - current_script_run_ctx.flags = GET_OR_RET(ParseInt(lua_tostring(lua, -1))); + if (!lua_isnil(lua, -1)) { + int isnum = false; + auto script_flags = lua_tointegerx(lua, -1, &isnum); + if (!isnum) return {Status::NotOK, "Invalid Script flags"}; + current_script_run_ctx.flags |= script_flags; } lua_pop(lua, 1); - if (read_only) current_script_run_ctx.flags |= ScriptFlagType::kScriptNoWrites; + SaveOnRegistry(lua, REGISTRY_SCRIPT_RUN_CTX_NAME, ¤t_script_run_ctx); // For the Lua script, should be always run with RESP2 protocol, @@ -1503,7 +1508,7 @@ Status CreateFunction(Server *srv, const std::string &body, std::string *sha, lu // scripts without #! can run commands that access keys belonging to different cluster hash slots script_flags = kScriptAllowCrossSlotKeys; } - lua_pushstring(lua, std::to_string(script_flags).c_str()); + lua_pushinteger(lua, static_cast(script_flags)); lua_setglobal(lua, fmt::format(REDIS_LUA_FUNC_SHA_FLAGS, *sha).c_str()); // would store lua function into propagate column family and propagate those scripts to slaves From c8a50bcdded59f9fb94a1189162622d89c87d612 Mon Sep 17 00:00:00 2001 From: ZhouSiLe Date: Sat, 10 Aug 2024 14:47:01 +0800 Subject: [PATCH 11/13] fix: lua_tointegerx => lua_tointeger --- src/storage/scripting.cc | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/src/storage/scripting.cc b/src/storage/scripting.cc index 632a6e56f95..00f1bac134a 100644 --- a/src/storage/scripting.cc +++ b/src/storage/scripting.cc @@ -399,9 +399,7 @@ Status FunctionCall(redis::Connection *conn, const std::string &name, const std: script_run_ctx.flags = read_only ? ScriptFlagType::kScriptNoWrites : 0; lua_getglobal(lua, (REDIS_LUA_REGISTER_FUNC_FLAGS_PREFIX + name).c_str()); if (!lua_isnil(lua, -1)) { - int isnum = false; - auto function_flags = lua_tointegerx(lua, -1, &isnum); - if (!isnum) return {Status::NotOK, "Invalid function flags"}; + auto function_flags = lua_tointeger(lua, -1); script_run_ctx.flags |= function_flags; } lua_pop(lua, 1); @@ -648,9 +646,7 @@ Status EvalGenericCommand(redis::Connection *conn, const std::string &body_or_sh current_script_run_ctx.flags = read_only ? ScriptFlagType::kScriptNoWrites : 0; lua_getglobal(lua, fmt::format(REDIS_LUA_FUNC_SHA_FLAGS, funcname + 2).c_str()); if (!lua_isnil(lua, -1)) { - int isnum = false; - auto script_flags = lua_tointegerx(lua, -1, &isnum); - if (!isnum) return {Status::NotOK, "Invalid Script flags"}; + auto script_flags = lua_tointeger(lua, -1); current_script_run_ctx.flags |= script_flags; } lua_pop(lua, 1); From f2fced3e2cc333bba71ed6f2ff379a7ff55173c2 Mon Sep 17 00:00:00 2001 From: ZhouSiLe Date: Sat, 10 Aug 2024 18:41:02 +0800 Subject: [PATCH 12/13] fix: Compatible with Lua 5.1 --- src/storage/scripting.cc | 22 ++++++++++++++-------- 1 file changed, 14 insertions(+), 8 deletions(-) diff --git a/src/storage/scripting.cc b/src/storage/scripting.cc index 00f1bac134a..3ce6d7adb4a 100644 --- a/src/storage/scripting.cc +++ b/src/storage/scripting.cc @@ -399,6 +399,7 @@ Status FunctionCall(redis::Connection *conn, const std::string &name, const std: script_run_ctx.flags = read_only ? ScriptFlagType::kScriptNoWrites : 0; lua_getglobal(lua, (REDIS_LUA_REGISTER_FUNC_FLAGS_PREFIX + name).c_str()); if (!lua_isnil(lua, -1)) { + // It should be ensured that the conversion is successful auto function_flags = lua_tointeger(lua, -1); script_run_ctx.flags |= function_flags; } @@ -646,6 +647,7 @@ Status EvalGenericCommand(redis::Connection *conn, const std::string &body_or_sh current_script_run_ctx.flags = read_only ? ScriptFlagType::kScriptNoWrites : 0; lua_getglobal(lua, fmt::format(REDIS_LUA_FUNC_SHA_FLAGS, funcname + 2).c_str()); if (!lua_isnil(lua, -1)) { + // It should be ensured that the conversion is successful auto script_flags = lua_tointeger(lua, -1); current_script_run_ctx.flags |= script_flags; } @@ -1488,17 +1490,14 @@ Status CreateFunction(Server *srv, const std::string &body, std::string *sha, lu std::copy(sha->begin(), sha->end(), funcname + 2); } - if (luaL_loadbuffer(lua, body.c_str(), body.size(), "@user_script")) { - std::string err_msg = lua_tostring(lua, -1); - lua_pop(lua, 1); - return {Status::NotOK, "Error while compiling new script: " + err_msg}; - } - lua_setglobal(lua, funcname); - + std::string_view lua_code(body); // Cache the flags of the current script ScriptFlags script_flags = 0; if (auto pos = body.find('\n'); pos != std::string::npos) { - auto first_line = body.substr(0, pos); + std::string_view first_line(body.data(), pos); + if (first_line.substr(0, 2) == "#!") { + lua_code = lua_code.substr(pos + 1); + } script_flags = GET_OR_RET(ExtractFlagsFromShebang(first_line)); } else { // scripts without #! can run commands that access keys belonging to different cluster hash slots @@ -1507,6 +1506,13 @@ Status CreateFunction(Server *srv, const std::string &body, std::string *sha, lu lua_pushinteger(lua, static_cast(script_flags)); lua_setglobal(lua, fmt::format(REDIS_LUA_FUNC_SHA_FLAGS, *sha).c_str()); + if (luaL_loadbuffer(lua, lua_code.data(), lua_code.size(), "@user_script")) { + std::string err_msg = lua_tostring(lua, -1); + lua_pop(lua, 1); + return {Status::NotOK, "Error while compiling new script: " + err_msg}; + } + lua_setglobal(lua, funcname); + // would store lua function into propagate column family and propagate those scripts to slaves return need_to_store ? srv->ScriptSet(*sha, body) : Status::OK(); } From 6e474465ad7e6677131ddc2194cc079bd1af4f88 Mon Sep 17 00:00:00 2001 From: ZhouSiLe Date: Mon, 12 Aug 2024 17:58:04 +0800 Subject: [PATCH 13/13] refactor: RemoveFromRegistry --- src/storage/scripting.cc | 10 ++++++++-- src/storage/scripting.h | 2 ++ 2 files changed, 10 insertions(+), 2 deletions(-) diff --git a/src/storage/scripting.cc b/src/storage/scripting.cc index 3ce6d7adb4a..bfe1016ee70 100644 --- a/src/storage/scripting.cc +++ b/src/storage/scripting.cc @@ -418,7 +418,7 @@ Status FunctionCall(redis::Connection *conn, const std::string &name, const std: lua_pop(lua, 2); } - SaveOnRegistry(lua, REGISTRY_SCRIPT_RUN_CTX_NAME, nullptr); + RemoveFromRegistry(lua, REGISTRY_SCRIPT_RUN_CTX_NAME); /* Call the Lua garbage collector from time to time to avoid a * full cycle performed by Lua, which adds too latency. @@ -681,7 +681,7 @@ Status EvalGenericCommand(redis::Connection *conn, const std::string &body_or_sh lua_setglobal(lua, "KEYS"); lua_pushnil(lua); lua_setglobal(lua, "ARGV"); - SaveOnRegistry(lua, REGISTRY_SCRIPT_RUN_CTX_NAME, nullptr); + RemoveFromRegistry(lua, REGISTRY_SCRIPT_RUN_CTX_NAME); /* Call the Lua garbage collector from time to time to avoid a * full cycle performed by Lua, which adds too latency. @@ -1630,4 +1630,10 @@ Status CreateFunction(Server *srv, const std::string &body, std::string *sha, lu return GetFlagsFromStrings(flags_content); } +void RemoveFromRegistry(lua_State *lua, const char *name) { + lua_pushstring(lua, name); + lua_pushnil(lua); + lua_settable(lua, LUA_REGISTRYINDEX); +} + } // namespace lua diff --git a/src/storage/scripting.h b/src/storage/scripting.h index 0cfb8b5f634..6cfa31f066b 100644 --- a/src/storage/scripting.h +++ b/src/storage/scripting.h @@ -185,4 +185,6 @@ T *GetFromRegistry(lua_State *lua, const char *name) { return ptr; } +void RemoveFromRegistry(lua_State *lua, const char *name); + } // namespace lua