
Commit df0bd7e
Make linear layer sharding specs share a common base class LinearLayerSharding
sogartar committed Sep 25, 2024
1 parent 8a24c5a commit df0bd7e
Showing 1 changed file with 27 additions and 32 deletions.
59 changes: 27 additions & 32 deletions sharktank/sharktank/types/sharding.py
@@ -143,48 +143,43 @@ def theta_sharding(self) -> ThetaSharding:
         )
 
 
-class LinearSplitParallelWeightAndBiasSharding(ThetaLayerSharding):
-    def __init__(self, shard_count: int, weight_and_bias_spit_dim: int = 0):
-        """Split one parallel dimension for both the weight and bias.
-        Since the weight is transposed before multiplying, the weight parallel
-        dimension is the same as the output(bias) dimension."""
+class LinearLayerSharding(ThetaLayerSharding):
+    def __init__(
+        self, premul_input: TensorSharding, weight: TensorSharding, bias: TensorSharding
+    ):
         super().__init__()
-        self.shard_count = shard_count
-        self.weight_and_bias_spit_dim = weight_and_bias_spit_dim
+        self.premul_input = premul_input
+        self.weight = weight
+        self.bias = bias
 
     def theta_sharding(self) -> ThetaSharding:
         return ThetaSharding(
             {
-                "premul_input": Replicated(shard_count=self.shard_count),
-                "weight": Split(
-                    shard_count=self.shard_count,
-                    shard_dim=self.weight_and_bias_spit_dim,
-                ),
-                "bias": Split(
-                    shard_count=self.shard_count,
-                    shard_dim=self.weight_and_bias_spit_dim,
-                ),
+                "premul_input": self.premul_input,
+                "weight": self.weight,
+                "bias": self.bias,
             }
         )
 
 
-class LinearSplitReductionDimSharding(ThetaLayerSharding):
-    def __init__(self, shard_count: int):
-        super().__init__()
-        self.shard_count = shard_count
+class LinearSplitParallelWeightAndBiasSharding(LinearLayerSharding):
+    def __init__(self, shard_count: int, weight_and_bias_spit_dim: int = 0):
+        """Split one parallel dimension for both the weight and bias.
+        Since the weight is transposed before multiplying, the weight parallel
+        dimension is the same as the output(bias) dimension."""
+        super().__init__(
+            premul_input=Replicated(shard_count=shard_count),
+            weight=Split(shard_count=shard_count, shard_dim=weight_and_bias_spit_dim),
+            bias=Split(shard_count=shard_count, shard_dim=weight_and_bias_spit_dim),
+        )
 
-    def theta_sharding(self) -> ThetaSharding:
-        return ThetaSharding(
-            {
-                "premul_input": Replicated(shard_count=self.shard_count),
-                "weight": Split(
-                    shard_count=self.shard_count,
-                    shard_dim=1,
-                ),
-                "bias": Replicated(
-                    shard_count=self.shard_count,
-                ),
-            }
+
+class LinearSplitReductionDimSharding(LinearLayerSharding):
+    def __init__(self, shard_count: int):
+        super().__init__(
+            premul_input=Replicated(shard_count=shard_count),
+            weight=Split(shard_count=shard_count, shard_dim=1),
+            bias=Replicated(shard_count=shard_count),
         )
 
 
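For orientation, a minimal usage sketch of the refactored classes follows. It is not part of the commit; it assumes the classes are importable from sharktank.types.sharding (matching the file path above) and uses an arbitrary shard count of 2 for illustration.

# Usage sketch (assumptions: module path sharktank.types.sharding; shard_count=2
# is arbitrary and chosen only for illustration).
from sharktank.types.sharding import (
    LinearSplitParallelWeightAndBiasSharding,
    LinearSplitReductionDimSharding,
)

# Weight and bias split along their shared parallel (output) dimension;
# the pre-multiplication input is replicated on every shard.
parallel_spec = LinearSplitParallelWeightAndBiasSharding(shard_count=2)
print(parallel_spec.theta_sharding())

# Weight split along its reduction dimension (dim 1); the bias is replicated.
reduction_spec = LinearSplitReductionDimSharding(shard_count=2)
print(reduction_spec.theta_sharding())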
