From 76eebc3b996cdddc12b54e70e9d017acf8543747 Mon Sep 17 00:00:00 2001 From: zhujch Date: Tue, 7 Nov 2023 18:00:52 -0500 Subject: [PATCH] Remove redundant frule definitions --- Project.toml | 3 ++- src/codegen.jl | 35 ++++++----------------------------- 2 files changed, 8 insertions(+), 30 deletions(-) diff --git a/Project.toml b/Project.toml index e7690c0..06f2d0e 100644 --- a/Project.toml +++ b/Project.toml @@ -11,15 +11,16 @@ IrrationalConstants = "92d709cd-6900-40b7-9082-c6be49f344b6" SliceMap = "82cb661a-3f19-5665-9e27-df437c7e54c8" SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b" SymbolicUtils = "d1185830-fcd6-423d-90d6-eec64667417b" +Symbolics = "0c5d862f-8b57-4792-8d23-62f2024744c7" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [compat] ChainRules = "1" ChainRulesCore = "1" ChainRulesOverloadGeneration = "0.1" +IrrationalConstants = "0.2" SliceMap = "0.2" SpecialFunctions = "2" -IrrationalConstants = "0.2" SymbolicUtils = "1" Zygote = "0.6.55" julia = "1.6" diff --git a/src/codegen.jl b/src/codegen.jl index 5e27b04..10107d6 100644 --- a/src/codegen.jl +++ b/src/codegen.jl @@ -1,36 +1,13 @@ +using ChainRules using ChainRulesCore using SpecialFunctions using IrrationalConstants: sqrtπ +using Symbolics: @variables using SymbolicUtils, SymbolicUtils.Code -using SymbolicUtils: BasicSymbolic, Pow - -@scalar_rule +(x::BasicSymbolic) true -@scalar_rule -(x::BasicSymbolic) -1 -@scalar_rule deg2rad(x::BasicSymbolic) deg2rad(one(x)) -@scalar_rule rad2deg(x::BasicSymbolic) rad2deg(one(x)) -@scalar_rule asin(x::BasicSymbolic) inv(sqrt(1 - x^2)) -@scalar_rule acos(x::BasicSymbolic) inv(-sqrt(1 - x^2)) -@scalar_rule atan(x::BasicSymbolic) inv(-(1 + x^2)) -@scalar_rule acot(x::BasicSymbolic) inv(-(1 + x^2)) -@scalar_rule acsc(x::BasicSymbolic) inv(x^2 * -sqrt(1 - x^-2)) -@scalar_rule asec(x::BasicSymbolic) inv(x^2 * sqrt(1 - x^-2)) -@scalar_rule log(x::BasicSymbolic) inv(x) -@scalar_rule log10(x::BasicSymbolic) inv(log(10.0) * x) -@scalar_rule log1p(x::BasicSymbolic) inv(x + 1) -@scalar_rule log2(x::BasicSymbolic) inv(log(2.0) * x) -@scalar_rule sinh(x::BasicSymbolic) cosh(x) -@scalar_rule cosh(x::BasicSymbolic) sinh(x) -@scalar_rule tanh(x::BasicSymbolic) 1-Ω^2 -@scalar_rule acosh(x::BasicSymbolic) inv(sqrt(x - 1) * sqrt(x + 1)) -@scalar_rule acoth(x::BasicSymbolic) inv(1 - x^2) -@scalar_rule acsch(x::BasicSymbolic) inv(x^2 * -sqrt(1 + x^-2)) -@scalar_rule asech(x::BasicSymbolic) inv(x * -sqrt(1 - x^2)) -@scalar_rule asinh(x::BasicSymbolic) inv(sqrt(x^2 + 1)) -@scalar_rule atanh(x::BasicSymbolic) inv(1 - x^2) -@scalar_rule erf(x::BasicSymbolic) exp(-x^2) * 2/sqrtπ +using SymbolicUtils: Pow dummy = (NoTangent(), 1) -@syms t₁ +@variables z for func in (+, -, deg2rad, rad2deg, sinh, cosh, tanh, asin, acos, atan, asec, acsc, acot, @@ -43,7 +20,7 @@ for func in (+, -, deg2rad, rad2deg, t0, t1 = value(t) TaylorScalar{T, 2}(frule((NoTangent(), t1), op, t0)) end - der = frule(dummy, func, t₁)[2] + der = frule(dummy, func, z)[2] term, raiser = der isa Pow && der.exp == -1 ? (der.base, raiseinv) : (der, raise) # recursion by raising @eval @generated function (op::$F)(t::TaylorScalar{T, N}) where {T, N} @@ -51,7 +28,7 @@ for func in (+, -, deg2rad, rad2deg, f = $func quote $(Expr(:meta, :inline)) - t₁ = TaylorScalar{T, N - 1}(t) + z = TaylorScalar{T, N - 1}(t) df = $der_expr $$raiser($f(value(t)[1]), df, t) end