Example-59: Matrix (element)
[1]:
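# Matrix element: can be used to model a generic 4x4 linear element.
# The transport matrix M is constructed from the 10 upper-triangular elements
# (a11, a12, a13, a14, a22, a23, a24, a33, a34, a44) of a symmetric 4x4 matrix A,
# passed as the `A` argument, via M = exp(S A)
# (S is presumably the 4x4 symplectic unit matrix; confirm against model.library.matrix).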
[2]:
import torch
from model.library.drift import Drift
from model.library.matrix import Matrix
[3]:
# Initial 4-component state (float64) shared by every cell below.
# Presumably ordered as (q_x, p_x, q_y, p_y): the drift Jacobians below couple
# component 0 with 1 and component 2 with 3. TODO confirm against the library.
state = torch.tensor(
    (0.01, -0.05, -0.01, 0.05),
    dtype=torch.float64,
)
[4]:
# Drift (forward): build a 1.5-length drift, then print the propagated state
# and its Jacobian with respect to `state`, each followed by a blank line.
D = Drift('D', length=1.5)
jac_D = torch.func.jacrev(D)
print(D(state), end='\n\n')
print(jac_D(state), end='\n\n')
tensor([-0.0650, -0.0500, 0.0650, 0.0500], dtype=torch.float64)
tensor([[1.0000, 1.5000, 0.0000, 0.0000],
[0.0000, 1.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 1.0000, 1.5000],
[0.0000, 0.0000, 0.0000, 1.0000]], dtype=torch.float64)
[5]:
# Drift (inverse): invert a 1.5-length drift, then print the propagated state
# and its Jacobian with respect to `state`, each followed by a blank line.
D = Drift('D', length=1.5).inverse()
jac_D_inv = torch.func.jacrev(D)
print(D(state), end='\n\n')
print(jac_D_inv(state), end='\n\n')
tensor([ 0.0850, -0.0500, -0.0850, 0.0500], dtype=torch.float64)
tensor([[ 1.0000, -1.5000, 0.0000, 0.0000],
[ 0.0000, 1.0000, 0.0000, 0.0000],
[ 0.0000, 0.0000, 1.0000, -1.5000],
[ 0.0000, 0.0000, 0.0000, 1.0000]], dtype=torch.float64)
[6]:
# Matrix (forward): a generic Matrix element configured to reproduce the
# 1.5-length drift from the "Drift (forward)" cell.
# `A` takes the 10 upper-triangular elements of the symmetric 4x4 matrix,
# row by row; only a22 and a44 are non-zero here.
#             a11  a12  a13  a14  a22  a23  a24  a33  a34  a44
A_elements = [0.0, 0.0, 0.0, 0.0, 1.5, 0.0, 0.0, 0.0, 0.0, 1.5]
M = Matrix('M', length=1.5, A=A_elements)
M_state = M(state)
M_jacobian = torch.func.jacrev(M)(state)
print(M_state, end='\n\n')
print(M_jacobian, end='\n\n')

# Make the equivalence explicit instead of relying on comparing printed tensors:
# the Matrix element must match the reference drift in both value and Jacobian.
drift_reference = Drift('D', length=1.5)
assert torch.allclose(M_state, drift_reference(state))
assert torch.allclose(M_jacobian, torch.func.jacrev(drift_reference)(state))
tensor([-0.0650, -0.0500, 0.0650, 0.0500], dtype=torch.float64)
tensor([[1.0000, 1.5000, 0.0000, 0.0000],
[0.0000, 1.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 1.0000, 1.5000],
[0.0000, 0.0000, 0.0000, 1.0000]], dtype=torch.float64)
[7]:
# Matrix (inverse): the inverse of the Matrix element that reproduces the
# 1.5-length drift; it should match the "Drift (inverse)" cell.
# `A` takes the 10 upper-triangular elements of the symmetric 4x4 matrix,
# row by row; only a22 and a44 are non-zero here.
#             a11  a12  a13  a14  a22  a23  a24  a33  a34  a44
A_elements = [0.0, 0.0, 0.0, 0.0, 1.5, 0.0, 0.0, 0.0, 0.0, 1.5]
M = Matrix('M', length=1.5, A=A_elements).inverse()
M_inv_state = M(state)
M_inv_jacobian = torch.func.jacrev(M)(state)
print(M_inv_state, end='\n\n')
print(M_inv_jacobian, end='\n\n')

# Make the equivalence explicit: the inverted Matrix element must match the
# inverted reference drift in both value and Jacobian.
drift_reference_inv = Drift('D', length=1.5).inverse()
assert torch.allclose(M_inv_state, drift_reference_inv(state))
assert torch.allclose(M_inv_jacobian, torch.func.jacrev(drift_reference_inv)(state))
tensor([ 0.0850, -0.0500, -0.0850, 0.0500], dtype=torch.float64)
tensor([[ 1.0000, -1.5000, 0.0000, 0.0000],
[ 0.0000, 1.0000, 0.0000, 0.0000],
[ 0.0000, 0.0000, 1.0000, -1.5000],
[ 0.0000, 0.0000, 0.0000, 1.0000]], dtype=torch.float64)
[ ]: