Example-59: Matrix (element)

[1]:
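# Matrix element: can be used to model generic 4x4 linear elements.
# The transport matrix is constructed from the 10 upper-triangular elements
# (a11, a12, a13, a14, a22, a23, a24, a33, a34, a44) of a symmetric matrix A as
# M = exp(S A), where S is the 4x4 symplectic unit matrix.
# Below, a22 = a44 = 1.5 (all other elements zero) reproduces a drift of length 1.5.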
[2]:
import torch
from model.library.drift import Drift
from model.library.matrix import Matrix
[3]:
# Set state: a float64 4-vector of initial coordinates
# (ordering presumably qx, px, qy, py -- matches the 2x2 block structure of the Jacobians below)

state = torch.tensor((0.01, -0.05, -0.01, 0.05), dtype=torch.float64)
[4]:
# Drift (forward): transform the state, then print the Jacobian of the map
# (each print is followed by a blank line)

D = Drift('D', length=1.5)

print(D(state), end='\n\n')
print(torch.func.jacrev(D)(state), end='\n\n')
tensor([-0.0650, -0.0500,  0.0650,  0.0500], dtype=torch.float64)

tensor([[1.0000, 1.5000, 0.0000, 0.0000],
        [0.0000, 1.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 1.0000, 1.5000],
        [0.0000, 0.0000, 0.0000, 1.0000]], dtype=torch.float64)

[5]:
# Drift (inverse): the inverted element; its Jacobian should undo the forward drift
# (each print is followed by a blank line)

D = Drift('D', length=1.5).inverse()

print(D(state), end='\n\n')
print(torch.func.jacrev(D)(state), end='\n\n')
tensor([ 0.0850, -0.0500, -0.0850,  0.0500], dtype=torch.float64)

tensor([[ 1.0000, -1.5000,  0.0000,  0.0000],
        [ 0.0000,  1.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  1.0000, -1.5000],
        [ 0.0000,  0.0000,  0.0000,  1.0000]], dtype=torch.float64)

[6]:
# Matrix (forward)
# With a22 = a44 = 1.5 and every other element zero, the Matrix element should
# reproduce the forward drift of length 1.5; this is asserted below, not just eyeballed.

[a11, a12, a13, a14, a22, a23, a24, a33, a34, a44] = [0.0, 0.0, 0.0, 0.0, 1.5, 0.0, 0.0, 0.0, 0.0, 1.5]

M = Matrix('M', length=1.5, A=[a11, a12, a13, a14, a22, a23, a24, a33, a34, a44])

print(M(state))
print()

print(torch.func.jacrev(M)(state))
print()

# Cross-check against an equivalent drift (assert_close also compares dtype and shape).
drift_reference = Drift('D', length=1.5)
torch.testing.assert_close(M(state), drift_reference(state))
torch.testing.assert_close(torch.func.jacrev(M)(state), torch.func.jacrev(drift_reference)(state))
tensor([-0.0650, -0.0500,  0.0650,  0.0500], dtype=torch.float64)

tensor([[1.0000, 1.5000, 0.0000, 0.0000],
        [0.0000, 1.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 1.0000, 1.5000],
        [0.0000, 0.0000, 0.0000, 1.0000]], dtype=torch.float64)

[7]:
# Matrix (inverse)
# The inverted Matrix element (a22 = a44 = 1.5, all other elements zero) should
# match the inverted drift of length 1.5; this is asserted below, not just eyeballed.

[a11, a12, a13, a14, a22, a23, a24, a33, a34, a44] = [0.0, 0.0, 0.0, 0.0, 1.5, 0.0, 0.0, 0.0, 0.0, 1.5]

M = Matrix('M', length=1.5, A=[a11, a12, a13, a14, a22, a23, a24, a33, a34, a44]).inverse()

print(M(state))
print()

print(torch.func.jacrev(M)(state))
print()

# Cross-check against an equivalent inverse drift (assert_close also compares dtype and shape).
drift_reference = Drift('D', length=1.5).inverse()
torch.testing.assert_close(M(state), drift_reference(state))
torch.testing.assert_close(torch.func.jacrev(M)(state), torch.func.jacrev(drift_reference)(state))
tensor([ 0.0850, -0.0500, -0.0850,  0.0500], dtype=torch.float64)

tensor([[ 1.0000, -1.5000,  0.0000,  0.0000],
        [ 0.0000,  1.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  1.0000, -1.5000],
        [ 0.0000,  0.0000,  0.0000,  1.0000]], dtype=torch.float64)

[ ]: