vshankar
vshankar

Reputation: 23

Scalar operations during backpropagation of neural ODE with DiffEqFlux on GPU

I'm using the DiffEqFlux package in Julia to implement a neural ODE. I'm having issues with getting it to work on GPU.

Simplest example is here:

# Minimal reproduction: forward and backward pass of a neural ODE on the GPU.
using DifferentialEquations
using Flux, DiffEqFlux
using CuArrays
# Make any scalar indexing of a GPU array raise an error instead of silently
# falling back to slow element-by-element access.
CuArrays.allowscalar(false)

# State vector handed to `neural_ode`, moved to the GPU.
x = Float32[2.; 0.]|>gpu
# Time span for the solve; the literals are already Float32, so the broadcast
# conversion does not change them.
tspan = Float32.((0.0f0,25.0f0))
# Network used as the ODE right-hand side: 2 -> 50 (tanh) -> 2, moved to the GPU.
dudt = Chain(Dense(2,50,tanh),Dense(50,2))|>gpu

# Scalar loss: sum of everything `neural_ode` returns for the solve.
# With `save_everystep=false` and `save_start=false` this is presumably only
# the final state -- NOTE(review): confirm against the DiffEqFlux version in use.
loss() = sum(neural_ode(dudt,x,tspan,Tsit5(),save_everystep=false,save_start=false))

# Forward pass only.
@show(loss())
# Re-evaluates the loss and backpropagates through it; this is the call that
# raises the reported error.
Flux.back!(loss())

and stacktrace:

loss() = -24.529072f0 (tracked)
ERROR: LoadError: scalar getindex is disallowed
Stacktrace:
 [1] error(::String) at ./error.jl:33
 [2] assertscalar(::String) at /home/vshanka2/.julia/packages/GPUArrays/1wgPO/src/indexing.jl:14
 [3] getindex at /home/vshanka2/.julia/packages/GPUArrays/1wgPO/src/indexing.jl:54 [inlined]
 [4] getindex at /buildworker/worker/package_linux64/build/usr/share/julia/stdlib/v1.1/LinearAlgebra/src/adjtrans.jl:129 [inlined]
 [5] _unsafe_getindex_rs at ./reshapedarray.jl:245 [inlined]
 [6] _unsafe_getindex at ./reshapedarray.jl:242 [inlined]
 [7] getindex at ./reshapedarray.jl:231 [inlined]
 [8] macro expansion at ./multidimensional.jl:671 [inlined]
 [9] macro expansion at ./cartesian.jl:64 [inlined]
 [10] macro expansion at ./multidimensional.jl:666 [inlined]
 [11] _unsafe_getindex! at ./multidimensional.jl:662 [inlined]
 [12] _unsafe_getindex(::IndexLinear, ::Base.ReshapedArray{Float32,1,LinearAlgebra.Adjoint{Float32,CuArray{Float32,1,Nothing}},Tuple{}}, ::UnitRange{Int64}) at ./multidimensional.jl:656
 [13] _getindex at ./multidimensional.jl:642 [inlined]
 [14] getindex(::Base.ReshapedArray{Float32,1,LinearAlgebra.Adjoint{Float32,CuArray{Float32,1,Nothing}},Tuple{}}, ::UnitRange{Int64}) at ./abstractarray.jl:927
 [15] (::getfield(Tracker, Symbol("##429#432")){Base.ReshapedArray{Float32,1,LinearAlgebra.Adjoint{Float32,CuArray{Float32,1,Nothing}},Tuple{}}})(::TrackedArray{…,CuArray{Float32,1,CuArray{Float32,2,Nothing}}}) at /home/vshanka2/.julia/packages/Tracker/cpxco/src/lib/array.jl:196
 [16] iterate at ./generator.jl:47 [inlined]
 [17] collect(::Base.Generator{Tuple{TrackedArray{…,CuArray{Float32,1,CuArray{Float32,2,Nothing}}},TrackedArray{…,CuArray{Float32,1,Nothing}},TrackedArray{…,CuArray{Float32,1,CuArray{Float32,2,Nothing}}},TrackedArray{…,CuArray{Float32,1,Nothing}}},getfield(Tracker, Symbol("##429#432")){Base.ReshapedArray{Float32,1,LinearAlgebra.Adjoint{Float32,CuArray{Float32,1,Nothing}},Tuple{}}}}) at ./array.jl:606
 [18] #428 at /home/vshanka2/.julia/packages/Tracker/cpxco/src/lib/array.jl:193 [inlined]
 [19] back_(::Tracker.Call{getfield(Tracker, Symbol("##428#431")){Tuple{TrackedArray{…,CuArray{Float32,1,CuArray{Float32,2,Nothing}}},TrackedArray{…,CuArray{Float32,1,Nothing}},TrackedArray{…,CuArray{Float32,1,CuArray{Float32,2,Nothing}}},TrackedArray{…,CuArray{Float32,1,Nothing}}}},Tuple{Tracker.Tracked{CuArray{Float32,1,CuArray{Float32,2,Nothing}}},Tracker.Tracked{CuArray{Float32,1,Nothing}},Tracker.Tracked{CuArray{Float32,1,CuArray{Float32,2,Nothing}}},Tracker.Tracked{CuArray{Float32,1,Nothing}}}}, ::Base.ReshapedArray{Float32,1,LinearAlgebra.Adjoint{Float32,CuArray{Float32,1,Nothing}},Tuple{}}, ::Bool) at /home/vshanka2/.julia/packages/Tracker/cpxco/src/back.jl:35
 [20] back(::Tracker.Tracked{CuArray{Float32,1,Nothing}}, ::Base.ReshapedArray{Float32,1,LinearAlgebra.Adjoint{Float32,CuArray{Float32,1,Nothing}},Tuple{}}, ::Bool) at /home/vshanka2/.julia/packages/Tracker/cpxco/src/back.jl:58
 [21] (::getfield(Tracker, Symbol("##13#14")){Bool})(::Tracker.Tracked{CuArray{Float32,1,Nothing}}, ::Base.ReshapedArray{Float32,1,LinearAlgebra.Adjoint{Float32,CuArray{Float32,1,Nothing}},Tuple{}}) at /home/vshanka2/.julia/packages/Tracker/cpxco/src/back.jl:38
 [22] foreach(::Function, ::Tuple{Tracker.Tracked{CuArray{Float32,1,Nothing}},Nothing,Nothing,Nothing}, ::Tuple{Base.ReshapedArray{Float32,1,LinearAlgebra.Adjoint{Float32,CuArray{Float32,1,Nothing}},Tuple{}},CuArray{Float32,1,Nothing},Nothing,Nothing}) at ./abstractarray.jl:1867
 [23] back_(::Tracker.Call{getfield(DiffEqFlux, Symbol("##25#28")){DiffEqSensitivity.SensitivityAlg{0,true,Val{:central}},Base.Iterators.Pairs{Symbol,Bool,Tuple{Symbol},NamedTuple{(:save_everystep,),Tuple{Bool}}},TrackedArray{…,CuArray{Float32,1,Nothing}},CuArray{Float32,1,Nothing},Tuple{Tsit5},ODESolution{Float32,2,Array{CuArray{Float32,1,Nothing},1},Nothing,Nothing,Array{Float32,1},Array{Array{CuArray{Float32,1,Nothing},1},1},ODEProblem{CuArray{Float32,1,Nothing},Tuple{Float32,Float32},false,CuArray{Float32,1,Nothing},ODEFunction{false,getfield(DiffEqFlux, Symbol("#dudt_#32")){Chain{Tuple{Dense{typeof(tanh),TrackedArray{…,CuArray{Float32,2,Nothing}},TrackedArray{…,CuArray{Float32,1,Nothing}}},Dense{typeof(identity),TrackedArray{…,CuArray{Float32,2,Nothing}},TrackedArray{…,CuArray{Float32,1,Nothing}}}}}},LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{,Tuple{}}},DiffEqBase.StandardODEProblem},Tsit5,OrdinaryDiffEq.InterpolationData{ODEFunction{false,getfield(DiffEqFlux, Symbol("#dudt_#32")){Chain{Tuple{Dense{typeof(tanh),TrackedArray{…,CuArray{Float32,2,Nothing}},TrackedArray{…,CuArray{Float32,1,Nothing}}},Dense{typeof(identity),TrackedArray{…,CuArray{Float32,2,Nothing}},TrackedArray{…,CuArray{Float32,1,Nothing}}}}}},LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Array{CuArray{Float32,1,Nothing},1},Array{Float32,1},Array{Array{CuArray{Float32,1,Nothing},1},1},OrdinaryDiffEq.Tsit5ConstantCache{Float32,Float32}},DiffEqBase.DEStats}},Tuple{Tracker.Tracked{CuArray{Float32,1,Nothing}},Nothing,Nothing,Nothing}}, ::CuArray{Float32,1,Nothing}, ::Bool) at /home/vshanka2/.julia/packages/Tracker/cpxco/src/back.jl:38
 [24] back(::Tracker.Tracked{CuArray{Float32,1,Nothing}}, ::CuArray{Float32,1,Nothing}, ::Bool) at /home/vshanka2/.julia/packages/Tracker/cpxco/src/back.jl:58
 [25] #13 at /home/vshanka2/.julia/packages/Tracker/cpxco/src/back.jl:38 [inlined]
 [26] foreach at ./abstractarray.jl:1867 [inlined]
 [27] back_(::Tracker.Call{getfield(Tracker, Symbol("##484#485")){TrackedArray{…,CuArray{Float32,1,Nothing}}},Tuple{Tracker.Tracked{CuArray{Float32,1,Nothing}}}}, ::Float32, ::Bool) at /home/vshanka2/.julia/packages/Tracker/cpxco/src/back.jl:38
 [28] back(::Tracker.Tracked{Float32}, ::Int64, ::Bool) at /home/vshanka2/.julia/packages/Tracker/cpxco/src/back.jl:58
 [29] #back!#15 at /home/vshanka2/.julia/packages/Tracker/cpxco/src/back.jl:77 [inlined]
 [30] #back! at ./none:0 [inlined]
 [31] #back!#32 at /home/vshanka2/.julia/packages/Tracker/cpxco/src/lib/real.jl:16 [inlined]
 [32] back!(::Tracker.TrackedReal{Float32}) at /home/vshanka2/.julia/packages/Tracker/cpxco/src/lib/real.jl:14
 [33] top-level scope at none:0
 [34] include at ./boot.jl:326 [inlined]
 [35] include_relative(::Module, ::String) at ./loading.jl:1038
 [36] include(::Module, ::String) at ./sysimg.jl:29
 [37] include(::String) at ./client.jl:403
 [38] top-level scope at none:0

The forward pass works fine, but the backward pass triggers a scalar getindex. If I allow scalar indexing, there's no problem of course, but it's much slower than on the CPU (even for the larger problem I'm working on).

I'm not sure if it has to do with any package dependencies, but here is what I currently have installed.

  [fbb218c0] BSON v0.2.4
  [c5f51814] CUDAdrv v5.0.1
  [be33ccc6] CUDAnative v2.7.0
  [3a865a2d] CuArrays v1.6.0
  [aae7a2af] DiffEqFlux v0.7.0
  [9fdde737] DiffEqOperators v4.6.1
  [0c46a032] DifferentialEquations v6.9.0
  [587475ba] Flux v0.8.3
  [a98d9a8b] Interpolations v0.12.5
  [15e1cf62] NPZ v0.4.0
  [91a5bcdd] Plots v0.28.4

Any ideas?

Upvotes: 2

Views: 293

Answers (1)

Chris Rackauckas
Chris Rackauckas

Reputation: 19132

This is fixed on DiffEqFlux 0.10.1. Do ]up to upgrade, or ]add DiffEqFlux@0.10.1 to specifically ask for this version.

Upvotes: 1

Related Questions