-
Notifications
You must be signed in to change notification settings - Fork 1k
fcase support scalar condition, vectorize default and lazy-eval default #4264
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
3334d0c
06ad48a
614aab5
cf59c87
73483aa
188170f
4252244
33d0bd9
ebbb0b1
dcab7d3
57980b0
6ab016a
5d7e464
17574fe
2fb2056
fe380ec
6a9b9c0
17cbe83
b6d3a84
f89dc3b
19202c8
28dd248
64774d5
afd5d15
e4a933f
743871f
ee1c58c
4766135
b2c9f2e
8afc9ce
b20e151
0ff96d2
9cce20b
fba3f0c
3a16d75
e7ff3c3
faa8020
54f6e5a
0b44ed5
906c253
14d9513
bae7748
7df308d
3ef7d9e
5b13476
3f1f109
c5caab8
39846df
358f78e
de5bbd7
a653106
ffe92fb
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -201,105 +201,106 @@ SEXP fifelseR(SEXP l, SEXP a, SEXP b, SEXP na) { | |
| return ans; | ||
| } | ||
|
|
||
| SEXP fcaseR(SEXP na, SEXP rho, SEXP args) { | ||
| const int narg=length(args); | ||
| SEXP fcaseR(SEXP rho, SEXP args) { | ||
| const int narg=length(args); // `default` will take the last two positions | ||
| if (narg % 2) { | ||
| error(_("Received %d inputs; please supply an even number of arguments in ..., " | ||
| "consisting of logical condition, resulting value pairs (in that order). " | ||
| "Note that the default argument must be named explicitly, e.g., default=0"), narg); | ||
| "Note that the default argument must be named explicitly, e.g., default=0"), narg - 2); | ||
| } | ||
| if (narg==0) return R_NilValue; | ||
|
|
||
| SEXP cons0 = PROTECT(eval(SEXPPTR_RO(args)[0], rho)); | ||
| SEXP value0 = PROTECT(eval(SEXPPTR_RO(args)[1], rho)); // value0 will be compared to from loop so leave it protected throughout | ||
| SEXPTYPE type0 = TYPEOF(value0); | ||
| int64_t len0=xlength(cons0), len2=len0; | ||
| if (isS4(value0) && !INHERITS(value0, char_nanotime)) { | ||
| error(_("S4 class objects (except nanotime) are not supported. Please see https://github.com/Rdatatable/data.table/issues/4131.")); | ||
| // otherwise 'invalid type/length (S4/1) in vector allocation' from test 2132.3 | ||
| } | ||
| SEXP ans = PROTECT(allocVector(type0, len0)); | ||
| SEXP tracker = PROTECT(allocVector(INTSXP, len0)); | ||
| int *restrict p = INTEGER(tracker); | ||
| copyMostAttrib(value0, ans); | ||
|
|
||
| bool nonna=!isNull(na); | ||
| if (nonna) { | ||
| if (xlength(na) != 1) { | ||
| error(_("Length of 'default' must be 1.")); | ||
| } | ||
| SEXPTYPE tn = TYPEOF(na); | ||
| if (tn==LGLSXP && LOGICAL(na)[0]==NA_LOGICAL) { | ||
| nonna = false; | ||
| } else { | ||
| if (tn != type0) { | ||
| error(_("Resulting value is of type %s but 'default' is of type %s. " | ||
| "Please make sure that both arguments have the same type."), type2char(type0), type2char(tn)); | ||
| } | ||
| if (!R_compute_identical(PROTECT(getAttrib(value0,R_ClassSymbol)), PROTECT(getAttrib(na,R_ClassSymbol)), 0)) { | ||
| error(_("Resulting value has different class than 'default'. " | ||
| "Please make sure that both arguments have the same class.")); | ||
| } | ||
| UNPROTECT(2); | ||
| if (isFactor(value0)) { | ||
| if (!R_compute_identical(PROTECT(getAttrib(value0,R_LevelsSymbol)), PROTECT(getAttrib(na,R_LevelsSymbol)), 0)) { | ||
| error(_("Resulting value and 'default' are both type factor but their levels are different.")); | ||
| } | ||
| UNPROTECT(2); | ||
| } | ||
| } | ||
| } | ||
|
|
||
| int nprotect=0, l; | ||
| int64_t len0=0, len1=0, len2=0; | ||
| SEXP ans=R_NilValue, value0=R_NilValue, tracker=R_NilValue, whens=R_NilValue, thens=R_NilValue; | ||
| PROTECT_INDEX Iwhens, Ithens; | ||
| PROTECT_WITH_INDEX(whens, &Iwhens); nprotect++; | ||
| PROTECT_WITH_INDEX(thens, &Ithens); nprotect++; | ||
| SEXPTYPE type0=NILSXP; | ||
| // naout means if the output is scalar logic na | ||
| bool imask = true, naout = false, idefault = false; | ||
| int *restrict p = NULL; | ||
| const int n = narg/2; | ||
| for (int i=0; i<n; ++i) { | ||
| SEXP cons = PROTECT(i==0 ? cons0 : eval(SEXPPTR_RO(args)[2*i], rho)); // protect cons0 again for easy unprotect at the end of this loop | ||
| SEXP outs = PROTECT(i==0 ? value0 : eval(SEXPPTR_RO(args)[2*i+1], rho)); | ||
| if (isS4(outs) && !INHERITS(outs, char_nanotime)) { | ||
| idefault = i == (n - 1); // mark if the current eval is the `default` on R side | ||
| REPROTECT(whens = eval(SEXPPTR_RO(args)[2*i], rho), Iwhens); | ||
| REPROTECT(thens = eval(SEXPPTR_RO(args)[2*i+1], rho), Ithens); | ||
| if (isS4(thens) && !INHERITS(thens, char_nanotime)) { | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Is PROTECTion required in this branch (before erroring)?
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We don't need to right? That's my read of WRE/what I'm used to elsewhere (also on Garbage Collection):
|
||
| error(_("S4 class objects (except nanotime) are not supported. Please see https://github.com/Rdatatable/data.table/issues/4131.")); | ||
| } | ||
| if (!isLogical(cons)) { | ||
| error(_("Argument #%d must be logical."), 2*i+1); | ||
| if (!isLogical(whens)) { | ||
|
TysonStanley marked this conversation as resolved.
|
||
| error(_("Argument #%d must be logical but was of type %s."), 2*i+1, type2char(TYPEOF(whens))); | ||
| } | ||
| if (i>0) { | ||
| if (xlength(cons) != len0) { | ||
| error(_("Argument #%d has a different length than argument #1. " | ||
| "Please make sure all logical conditions have the same length."), | ||
| i*2+1); | ||
| const int *restrict pwhens = LOGICAL(whens); | ||
| l = 0; | ||
| if (i == 0) { | ||
| len0 = xlength(whens); | ||
| len2 = len0; | ||
| type0 = TYPEOF(thens); | ||
| value0 = thens; | ||
| ans = PROTECT(allocVector(type0, len0)); nprotect++; | ||
| copyMostAttrib(thens, ans); | ||
| tracker = PROTECT(allocVector(INTSXP, len0)); nprotect++; | ||
| p = INTEGER(tracker); | ||
| } else { | ||
| imask = false; | ||
| naout = xlength(thens) == 1 && TYPEOF(thens) == LGLSXP && LOGICAL(thens)[0]==NA_LOGICAL; | ||
| if (xlength(whens) != len0 && xlength(whens) != 1) { | ||
| // no need to check `idefault` here because the con for default is always `TRUE` | ||
| error(_("Argument #%d has length %lld which differs from that of argument #1 (%lld). " | ||
| "Please make sure all logical conditions have the same length or length 1."), | ||
| i*2+1, (long long)xlength(whens), (long long)len0); | ||
| } | ||
| if (TYPEOF(outs) != type0) { | ||
| error(_("Argument #%d is of type %s, however argument #2 is of type %s. " | ||
| "Please make sure all output values have the same type."), | ||
| i*2+2, type2char(TYPEOF(outs)), type2char(type0)); | ||
| if (!naout && TYPEOF(thens) != type0) { | ||
| if (idefault) { | ||
| error(_("Resulting value is of type %s but 'default' is of type %s. " | ||
| "Please make sure that both arguments have the same type."), type2char(type0), type2char(TYPEOF(thens))); | ||
| } else { | ||
| error(_("Argument #%d is of type %s, however argument #2 is of type %s. " | ||
| "Please make sure all output values have the same type."), | ||
| i*2+2, type2char(TYPEOF(thens)), type2char(type0)); | ||
| } | ||
| } | ||
| if (!R_compute_identical(PROTECT(getAttrib(value0,R_ClassSymbol)), PROTECT(getAttrib(outs,R_ClassSymbol)), 0)) { | ||
| error(_("Argument #%d has different class than argument #2, " | ||
| "Please make sure all output values have the same class."), i*2+2); | ||
| if (!naout) { | ||
| if (!R_compute_identical(PROTECT(getAttrib(value0, R_ClassSymbol)), PROTECT(getAttrib(thens, R_ClassSymbol)), 0)) { | ||
| if (idefault) { | ||
| error(_("Resulting value has different class than 'default'. " | ||
| "Please make sure that both arguments have the same class.")); | ||
| } else { | ||
| error(_("Argument #%d has different class than argument #2, " | ||
| "Please make sure all output values have the same class."), i*2+2); | ||
| } | ||
| } | ||
| UNPROTECT(2); // class(value0), class(thens) | ||
| } | ||
| UNPROTECT(2); | ||
| if (isFactor(value0)) { | ||
| if (!R_compute_identical(PROTECT(getAttrib(value0,R_LevelsSymbol)), PROTECT(getAttrib(outs,R_LevelsSymbol)), 0)) { | ||
| error(_("Argument #2 and argument #%d are both factor but their levels are different."), i*2+2); | ||
| if (!naout && isFactor(value0)) { | ||
| if (!R_compute_identical(PROTECT(getAttrib(value0, R_LevelsSymbol)), PROTECT(getAttrib(thens, R_LevelsSymbol)), 0)) { | ||
| if (idefault) { | ||
| error(_("Resulting value and 'default' are both type factor but their levels are different.")); | ||
| } else { | ||
| error(_("Argument #2 and argument #%d are both factor but their levels are different."), i*2+2); | ||
| } | ||
| } | ||
| UNPROTECT(2); | ||
| UNPROTECT(2); // levels(value0), levels(thens) | ||
| } | ||
| } | ||
| int64_t len1 = xlength(outs); | ||
| if (len1!=len0 && len1!=1) { | ||
| error(_("Length of output value #%d must either be 1 or length of logical condition."), i*2+2); | ||
| len1 = xlength(thens); | ||
| if (len1 != len0 && len1 != 1) { | ||
| if (idefault) { | ||
| error(_("Length of 'default' must be 1 or %lld."), (long long)len0); | ||
| } else { | ||
| error(_("Length of output value #%d (%lld) must either be 1 or match the length of the logical condition (%lld)."), i*2+2, (long long)len1, (long long)len0); | ||
| } | ||
| } | ||
| int64_t amask = len1>1 ? INT64_MAX : 0; | ||
| const int *restrict pcons = LOGICAL(cons); | ||
| const bool imask = i==0; | ||
| int64_t l=0; // how many this case didn't satisfy; i.e. left for next case | ||
| switch(TYPEOF(outs)) { | ||
| int64_t thenMask = len1>1 ? INT64_MAX : 0, whenMask = xlength(whens)>1 ? INT64_MAX : 0; | ||
| switch(TYPEOF(ans)) { | ||
| case LGLSXP: { | ||
| const int *restrict pouts = LOGICAL(outs); | ||
| const int *restrict pthens; | ||
| if (!naout) pthens = LOGICAL(thens); // the content is not useful if out is NA_LOGICAL scalar | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. (Applies to other cases too): Should the
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I had a similar thought about reorganizing the code... but probably better to explore as a follow-up. |
||
| int *restrict pans = LOGICAL(ans); | ||
| const int pna = nonna ? LOGICAL(na)[0] : NA_LOGICAL; | ||
| const int pna = NA_LOGICAL; | ||
| for (int64_t j=0; j<len2; ++j) { | ||
| const int64_t idx = imask ? j : p[j]; | ||
| if (pcons[idx]==1) { | ||
| pans[idx] = pouts[idx & amask]; | ||
| if (pwhens[idx & whenMask]==1) { | ||
| pans[idx] = naout ? pna : pthens[idx & thenMask]; | ||
| } else { | ||
| if (imask) { | ||
| pans[j] = pna; | ||
|
|
@@ -309,13 +310,14 @@ SEXP fcaseR(SEXP na, SEXP rho, SEXP args) { | |
| } | ||
| } break; | ||
| case INTSXP: { | ||
| const int *restrict pouts = INTEGER(outs); | ||
| const int *restrict pthens; | ||
| if (!naout) pthens = INTEGER(thens); // the content is not useful if out is NA_LOGICAL scalar | ||
| int *restrict pans = INTEGER(ans); | ||
| const int pna = nonna ? INTEGER(na)[0] : NA_INTEGER; | ||
| const int pna = NA_INTEGER; | ||
| for (int64_t j=0; j<len2; ++j) { | ||
| const int64_t idx = imask ? j : p[j]; | ||
| if (pcons[idx]==1) { | ||
| pans[idx] = pouts[idx & amask]; | ||
| if (pwhens[idx & whenMask]==1) { | ||
| pans[idx] = naout ? pna : pthens[idx & thenMask]; | ||
| } else { | ||
| if (imask) { | ||
| pans[j] = pna; | ||
|
|
@@ -325,14 +327,15 @@ SEXP fcaseR(SEXP na, SEXP rho, SEXP args) { | |
| } | ||
| } break; | ||
| case REALSXP: { | ||
| const double *restrict pouts = REAL(outs); | ||
| const double *restrict pthens; | ||
| if (!naout) pthens = REAL(thens); // the content is not useful if out is NA_LOGICAL scalar | ||
| double *restrict pans = REAL(ans); | ||
| const double na_double = INHERITS(outs, char_integer64) ? NA_INT64_D : NA_REAL; | ||
| const double pna = nonna ? REAL(na)[0] : na_double; | ||
| const double na_double = INHERITS(ans, char_integer64) ? NA_INT64_D : NA_REAL; | ||
| const double pna = na_double; | ||
| for (int64_t j=0; j<len2; ++j) { | ||
| const int64_t idx = imask ? j : p[j]; | ||
| if (pcons[idx]==1) { | ||
| pans[idx] = pouts[idx & amask]; | ||
| if (pwhens[idx & whenMask]==1) { | ||
| pans[idx] = naout ? pna : pthens[idx & thenMask]; | ||
| } else { | ||
| if (imask) { | ||
| pans[j] = pna; | ||
|
|
@@ -342,13 +345,14 @@ SEXP fcaseR(SEXP na, SEXP rho, SEXP args) { | |
| } | ||
| } break; | ||
| case CPLXSXP: { | ||
| const Rcomplex *restrict pouts = COMPLEX(outs); | ||
| const Rcomplex *restrict pthens; | ||
| if (!naout) pthens = COMPLEX(thens); // the content is not useful if out is NA_LOGICAL scalar | ||
| Rcomplex *restrict pans = COMPLEX(ans); | ||
| const Rcomplex pna = nonna ? COMPLEX(na)[0] : NA_CPLX; | ||
| const Rcomplex pna = NA_CPLX; | ||
| for (int64_t j=0; j<len2; ++j) { | ||
| const int64_t idx = imask ? j : p[j]; | ||
| if (pcons[idx]==1) { | ||
| pans[idx] = pouts[idx & amask]; | ||
| if (pwhens[idx & whenMask]==1) { | ||
| pans[idx] = naout ? pna : pthens[idx & thenMask]; | ||
| } else { | ||
| if (imask) { | ||
| pans[j] = pna; | ||
|
|
@@ -358,44 +362,43 @@ SEXP fcaseR(SEXP na, SEXP rho, SEXP args) { | |
| } | ||
| } break; | ||
| case STRSXP: { | ||
| const SEXP *restrict pouts = STRING_PTR_RO(outs); | ||
| const SEXP pna = nonna ? STRING_PTR_RO(na)[0] : NA_STRING; | ||
| const SEXP *restrict pthens; | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Is
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I also think it's fine, but OTOH, I can't tell whether |
||
| if (!naout) pthens = STRING_PTR_RO(thens); // the content is not useful if out is NA_LOGICAL scalar | ||
| const SEXP pna = NA_STRING; | ||
| for (int64_t j=0; j<len2; ++j) { | ||
| const int64_t idx = imask ? j : p[j]; | ||
| if (pcons[idx]==1) { | ||
| SET_STRING_ELT(ans, idx, pouts[idx & amask]); | ||
| if (pwhens[idx & whenMask]==1) { | ||
| SET_STRING_ELT(ans, idx, naout ? pna : pthens[idx & thenMask]); | ||
| } else { | ||
| if (imask) { | ||
| SET_STRING_ELT(ans, idx, pna); | ||
| SET_STRING_ELT(ans, j, pna); | ||
| } | ||
| p[l++] = idx; | ||
| } | ||
| } | ||
| } break; | ||
| case VECSXP: { | ||
| const SEXP *restrict pouts = SEXPPTR_RO(outs); | ||
| const SEXP pna = SEXPPTR_RO(na)[0]; | ||
| // the default value of VECSXP is `NULL` so we don't need to explicitly | ||
| // assign the NA values as it does for other atomic types | ||
| const SEXP *restrict pthens; | ||
| if (!naout) pthens = SEXPPTR_RO(thens); // the content is not useful if out is NA_LOGICAL scalar | ||
| for (int64_t j=0; j<len2; ++j) { | ||
| const int64_t idx = imask ? j : p[j]; | ||
| if (pcons[idx]==1) { | ||
| SET_VECTOR_ELT(ans, idx, pouts[idx & amask]); | ||
| if (pwhens[idx & whenMask]==1) { | ||
| if (!naout) SET_VECTOR_ELT(ans, idx, pthens[idx & thenMask]); | ||
| } else { | ||
| if (imask && nonna) { | ||
| SET_VECTOR_ELT(ans, idx, pna); | ||
| } | ||
| p[l++] = idx; | ||
| } | ||
| } | ||
| } break; | ||
| default: | ||
| error(_("Type '%s' is not supported"), type2char(TYPEOF(outs))); | ||
| error(_("Type '%s' is not supported."), type2char(TYPEOF(ans))); | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think we should mention what argument the "Type" is referring to.
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. One thing here is getting a consistent argument to make translation easier: So I would move this to a follow-up where we try and improve that message more generally, maybe as "Type '%s' is not supported in argument %s". |
||
| } | ||
| UNPROTECT(2); // this cons and outs | ||
| if (l==0) { | ||
| break; // stop early as nothing left to do | ||
| } | ||
| len2 = l; | ||
| } | ||
| UNPROTECT(4); // cons0, value0, ans, tracker | ||
| UNPROTECT(nprotect); // whens, thens, ans, tracker | ||
| return ans; | ||
| } | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't understand why
PROTECT_WITH_INDEXis being used. Could a comment be added explaining what these variables are and why this protection method was chosen?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The basic idea is to re-use memory as we go along, I tried naming these
when/thento evokeCASE WHEN A THEN B WHEN C THEN D .... SoPROTECT_INDEXis to make the memory forA,C,E, ... is in the same place, same forB,D,F, ...(AIUI)
I am not sure the benefits/tradeoffs here. The first commit in this branch was much simpler 3334d0c
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I am still noob at
PROTECT_WITH_INDEX()so I turned to LLM for some explanations...https://g.co/gemini/share/85cdc98bb748

https://chatgpt.com/share/a4792977-f24a-47da-95c0-da48a1ad7c4c
Claude (permalinks not supported?)
The main themes are pretty consistent -- the alternative involves doing
UNPROTECT(2)and thenPROTECT()twice again in each iteration, hence potentially lots of RAM churn that we avoid withPROTECT_WITH_INDEX().